diff --git a/.ai/claude.prompt.md b/.ai/claude.prompt.md new file mode 100644 index 000000000..7f38f5752 --- /dev/null +++ b/.ai/claude.prompt.md @@ -0,0 +1,9 @@ +## About This File + +This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository. + +## 1. Project Context +Here is the essential context for our project. Please read and understand it thoroughly. + +### Project Overview +@./context/01-overview.md diff --git a/.ai/context/01-overview.md b/.ai/context/01-overview.md new file mode 100644 index 000000000..41133e983 --- /dev/null +++ b/.ai/context/01-overview.md @@ -0,0 +1,101 @@ +This file provides the overview and guidance for developers working with the codebase, including setup instructions, architecture details, and common commands. + +## Project Architecture + +### Core Training Framework +The codebase is built around a **strategy pattern architecture** that supports multiple diffusion model families: + +- **`library/strategy_base.py`**: Base classes for tokenization, text encoding, latent caching, and training strategies +- **`library/strategy_*.py`**: Model-specific implementations for SD, SDXL, SD3, FLUX, etc. +- **`library/train_util.py`**: Core training utilities shared across all model types +- **`library/config_util.py`**: Configuration management with TOML support + +### Model Support Structure +Each supported model family has a consistent structure: +- **Training script**: `{model}_train.py` (full fine-tuning), `{model}_train_network.py` (LoRA/network training) +- **Model utilities**: `library/{model}_models.py`, `library/{model}_train_utils.py`, `library/{model}_utils.py` +- **Networks**: `networks/lora_{model}.py`, `networks/oft_{model}.py` for adapter training + +### Supported Models +- **Stable Diffusion 1.x**: `train*.py`, `library/train_util.py`, `train_db.py` (for DreamBooth) +- **SDXL**: `sdxl_train*.py`, `library/sdxl_*` +- **SD3**: `sd3_train*.py`, `library/sd3_*` +- **FLUX.1**: `flux_train*.py`, `library/flux_*` + +### Key Components + +#### Memory Management +- **Block swapping**: CPU-GPU memory optimization via `--blocks_to_swap` parameter, works with custom offloading. Only available for models with transformer architectures like SD3 and FLUX.1. +- **Custom offloading**: `library/custom_offloading_utils.py` for advanced memory management +- **Gradient checkpointing**: Memory reduction during training + +#### Training Features +- **LoRA training**: Low-rank adaptation networks in `networks/lora*.py` +- **ControlNet training**: Conditional generation control +- **Textual Inversion**: Custom embedding training +- **Multi-resolution training**: Bucket-based aspect ratio handling +- **Validation loss**: Real-time training monitoring, only for LoRA training + +#### Configuration System +Dataset configuration uses TOML files with structured validation: +```toml +[datasets.sample_dataset] + resolution = 1024 + batch_size = 2 + + [[datasets.sample_dataset.subsets]] + image_dir = "path/to/images" + caption_extension = ".txt" +``` + +## Common Development Commands + +### Training Commands Pattern +All training scripts follow this general pattern: +```bash +accelerate launch --mixed_precision bf16 {script_name}.py \ + --pretrained_model_name_or_path model.safetensors \ + --dataset_config config.toml \ + --output_dir output \ + --output_name model_name \ + [model-specific options] +``` + +### Memory Optimization +For low VRAM environments, use block swapping: +```bash +# Add to any training command for memory reduction +--blocks_to_swap 10 # Swap 10 blocks to CPU (adjust number as needed) +``` + +### Utility Scripts +Located in `tools/` directory: +- `tools/merge_lora.py`: Merge LoRA weights into base models +- `tools/cache_latents.py`: Pre-cache VAE latents for faster training +- `tools/cache_text_encoder_outputs.py`: Pre-cache text encoder outputs + +## Development Notes + +### Strategy Pattern Implementation +When adding support for new models, implement the four core strategies: +1. `TokenizeStrategy`: Text tokenization handling +2. `TextEncodingStrategy`: Text encoder forward pass +3. `LatentsCachingStrategy`: VAE encoding/caching +4. `TextEncoderOutputsCachingStrategy`: Text encoder output caching + +### Testing Approach +- Unit tests focus on utility functions and model loading +- Integration tests validate training script syntax and basic execution +- Most tests use mocks to avoid requiring actual model files +- Add tests for new model support in `tests/test_{model}_*.py` + +### Configuration System +- Use `config_util.py` dataclasses for type-safe configuration +- Support both command-line arguments and TOML file configuration +- Validate configuration early in training scripts to prevent runtime errors + +### Memory Management +- Always consider VRAM limitations when implementing features +- Use gradient checkpointing for large models +- Implement block swapping for models with transformer architectures +- Cache intermediate results (latents, text embeddings) when possible \ No newline at end of file diff --git a/.ai/gemini.prompt.md b/.ai/gemini.prompt.md new file mode 100644 index 000000000..6047390bc --- /dev/null +++ b/.ai/gemini.prompt.md @@ -0,0 +1,9 @@ +## About This File + +This file provides guidance to Gemini CLI (https://github.com/google-gemini/gemini-cli) when working with code in this repository. + +## 1. Project Context +Here is the essential context for our project. Please read and understand it thoroughly. + +### Project Overview +@./context/01-overview.md diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml new file mode 100644 index 000000000..d35fe3925 --- /dev/null +++ b/.github/workflows/tests.yml @@ -0,0 +1,51 @@ +name: Test with pytest + +on: + push: + branches: + - main + - dev + - sd3 + pull_request: + branches: + - main + - dev + - sd3 + +# CKV2_GHA_1: "Ensure top-level permissions are not set to write-all" +permissions: read-all + +jobs: + build: + runs-on: ${{ matrix.os }} + strategy: + matrix: + os: [ubuntu-latest] + python-version: ["3.10"] # Python versions to test + pytorch-version: ["2.4.0", "2.6.0"] # PyTorch versions to test + + steps: + - uses: actions/checkout@v4 + with: + # https://woodruffw.github.io/zizmor/audits/#artipacked + persist-credentials: false + + - uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + cache: 'pip' + + - name: Install and update pip, setuptools, wheel + run: | + # Setuptools, wheel for compiling some packages + python -m pip install --upgrade pip setuptools wheel + + - name: Install dependencies + run: | + # Pre-install torch to pin version (requirements.txt has dependencies like transformers which requires pytorch) + pip install dadaptation==3.2 torch==${{ matrix.pytorch-version }} torchvision pytest==8.3.4 + pip install -r requirements.txt + + - name: Test with pytest + run: pytest # See pytest.ini for configuration + diff --git a/.github/workflows/typos.yml b/.github/workflows/typos.yml index 0149dcdd3..b9d6acc98 100644 --- a/.github/workflows/typos.yml +++ b/.github/workflows/typos.yml @@ -1,21 +1,29 @@ --- -# yamllint disable rule:line-length name: Typos -on: # yamllint disable-line rule:truthy +on: push: + branches: + - main + - dev pull_request: types: - opened - synchronize - reopened +# CKV2_GHA_1: "Ensure top-level permissions are not set to write-all" +permissions: read-all + jobs: build: runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 + with: + # https://woodruffw.github.io/zizmor/audits/#artipacked + persist-credentials: false - name: typos-action - uses: crate-ci/typos@v1.24.3 + uses: crate-ci/typos@v1.28.1 diff --git a/.gitignore b/.gitignore index d48110130..cfdc02685 100644 --- a/.gitignore +++ b/.gitignore @@ -10,4 +10,4 @@ CLAUDE.md GEMINI.md .claude .gemini -MagicMock \ No newline at end of file +MagicMock diff --git a/README-ja.md b/README-ja.md index 71c3b0d54..27e15aa94 100644 --- a/README-ja.md +++ b/README-ja.md @@ -167,11 +167,12 @@ masterpiece, best quality, 1boy, in business suit, standing at street, looking b `#` で始まる行はコメントになります。`--n` のように「ハイフン二個+英小文字」の形でオプションを指定できます。以下が使用可能できます。 - * `--n` Negative prompt up to the next option. - * `--w` Specifies the width of the generated image. - * `--h` Specifies the height of the generated image. - * `--d` Specifies the seed of the generated image. - * `--l` Specifies the CFG scale of the generated image. - * `--s` Specifies the number of steps in the generation. + * `--n` ネガティブプロンプト(次のオプションまで) + * `--w` 生成画像の幅を指定 + * `--h` 生成画像の高さを指定 + * `--d` 生成画像のシード値を指定 + * `--l` 生成画像のCFGスケールを指定。FLUX.1モデルでは、デフォルトは `1.0` でCFGなしを意味します。Chromaモデルでは、CFGを有効にするために `4.0` 程度に設定してください + * `--g` 埋め込みガイダンス付きモデル(FLUX.1)の埋め込みガイダンススケールを指定、デフォルトは `3.5`。Chromaモデルでは `0.0` に設定してください + * `--s` 生成時のステップ数を指定 `( )` や `[ ]` などの重みづけも動作します。 diff --git a/README.md b/README.md index 629f1d415..c70dc257d 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,81 @@ This repository contains training, generation and utility scripts for Stable Diffusion. +## FLUX.1 and SD3 training (WIP) + +This feature is experimental. The options and the training script may change in the future. Please let us know if you have any idea to improve the training. + +__Please update PyTorch to 2.6.0 or later. We have tested with `torch==2.6.0` and `torchvision==0.21.0` with CUDA 12.4. `requirements.txt` is also updated, so please update the requirements.__ + +The command to install PyTorch is as follows: +`pip3 install torch==2.6.0 torchvision==0.21.0 --index-url https://download.pytorch.org/whl/cu124` + +For RTX 50 series GPUs, PyTorch 2.8.0 with CUDA 12.8/9 should be used. `requirements.txt` will work with this version. + +If you are using DeepSpeed, please install DeepSpeed with `pip install deepspeed` (appropriate version is not confirmed yet). + +### Recent Updates + +Sep 23, 2025: +- HunyuanImage-2.1 LoRA training is supported. [PR #2198](https://github.com/kohya-ss/sd-scripts/pull/2198) for details. + - Please see [HunyuanImage-2.1 Training](./docs/hunyuan_image_train_network.md) for details. + - __HunyuanImage-2.1 training does not support LoRA modules for Text Encoders, so `--network_train_unet_only` is required.__ + - The training script is `hunyuan_image_train_network.py`. + - This includes changes to `train_network.py`, the base of the training script. Please let us know if you encounter any issues. + +Sep 13, 2025: +- The loading speed of `.safetensors` files has been improved for SD3, FLUX.1 and Lumina. See [PR #2200](https://github.com/kohya-ss/sd-scripts/pull/2200) for more details. + - Model loading can be up to 1.5 times faster. + - This is a wide-ranging update, so there may be bugs. Please let us know if you encounter any issues. + +Sep 4, 2025: +- The information about FLUX.1 and SD3/SD3.5 training that was described in the README has been organized and divided into the following documents: + - [LoRA Training Overview](./docs/train_network.md) + - [SDXL Training](./docs/sdxl_train_network.md) + - [Advanced Training](./docs/train_network_advanced.md) + - [FLUX.1 Training](./docs/flux_train_network.md) + - [SD3 Training](./docs/sd3_train_network.md) + - [LUMINA Training](./docs/lumina_train_network.md) + - [Validation](./docs/validation.md) + - [Fine-tuning](./docs/fine_tune.md) + - [Textual Inversion Training](./docs/train_textual_inversion.md) + +Aug 28, 2025: +- In order to support the latest GPUs and features, we have updated the **PyTorch and library versions**. PR [#2178](https://github.com/kohya-ss/sd-scripts/pull/2178) There are many changes, so please let us know if you encounter any issues. +- The PyTorch version used for testing has been updated to 2.6.0. We have confirmed that it works with PyTorch 2.6.0 and later. +- The `requirements.txt` has been updated, so please update your dependencies. + - You can update the dependencies with `pip install -r requirements.txt`. + - The version specification for `bitsandbytes` has been removed. If you encounter errors on RTX 50 series GPUs, please update it with `pip install -U bitsandbytes`. +- We have modified each script to minimize warnings as much as possible. + - The modified scripts will work in the old environment (library versions), but please update them when convenient. + + +## For Developers Using AI Coding Agents + +This repository provides recommended instructions to help AI agents like Claude and Gemini understand our project context and coding standards. + +To use them, you need to opt-in by creating your own configuration file in the project root. + +**Quick Setup:** + +1. Create a `CLAUDE.md` and/or `GEMINI.md` file in the project root. +2. Add the following line to your `CLAUDE.md` to import the repository's recommended prompt: + + ```markdown + @./.ai/claude.prompt.md + ``` + + or for Gemini: + + ```markdown + @./.ai/gemini.prompt.md + ``` + +3. You can now add your own personal instructions below the import line (e.g., `Always respond in Japanese.`). + +This approach ensures that you have full control over the instructions given to your agent while benefiting from the shared project context. Your `CLAUDE.md` and `GEMINI.md` are already listed in `.gitignore`, so it won't be committed to the repository. + +--- + [__Change History__](#change-history) is moved to the bottom of the page. 更新履歴は[ページ末尾](#change-history)に移しました。 @@ -125,6 +201,14 @@ Note: Some user reports ``ValueError: fp16 mixed precision requires a GPU`` is o (Single GPU with id `0` will be used.) +## DeepSpeed installation (experimental, Linux or WSL2 only) + +To install DeepSpeed, run the following command in your activated virtual environment: + +```bash +pip install deepspeed==0.16.7 +``` + ## Upgrade When a new release comes out you can upgrade your repo with the following command: @@ -226,7 +310,7 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser - Fused optimizer is available for SDXL training. PR [#1259](https://github.com/kohya-ss/sd-scripts/pull/1259) Thanks to 2kpr! - The memory usage during training is significantly reduced by integrating the optimizer's backward pass with step. The training results are the same as before, but if you have plenty of memory, the speed will be slower. - - Specify the `--fused_backward_pass` option in `sdxl_train.py`. At this time, only AdaFactor is supported. Gradient accumulation is not available. + - Specify the `--fused_backward_pass` option in `sdxl_train.py`. At this time, only Adafactor is supported. Gradient accumulation is not available. - Setting mixed precision to `no` seems to use less memory than `fp16` or `bf16`. - Training is possible with a memory usage of about 17GB with a batch size of 1 and fp32. If you specify the `--full_bf16` option, you can further reduce the memory usage (but the accuracy will be lower). With the same memory usage as before, you can increase the batch size. - PyTorch 2.1 or later is required because it uses the new API `Tensor.register_post_accumulate_grad_hook(hook)`. @@ -236,7 +320,7 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser - Memory usage is reduced by the same principle as Fused optimizer. The training results and speed are the same as Fused optimizer. - Specify the number of groups like `--fused_optimizer_groups 10` in `sdxl_train.py`. Increasing the number of groups reduces memory usage but slows down training. Since the effect is limited to a certain number, it is recommended to specify 4-10. - Any optimizer can be used, but optimizers that automatically calculate the learning rate (such as D-Adaptation and Prodigy) cannot be used. Gradient accumulation is not available. - - `--fused_optimizer_groups` cannot be used with `--fused_backward_pass`. When using AdaFactor, the memory usage is slightly larger than with Fused optimizer. PyTorch 2.1 or later is required. + - `--fused_optimizer_groups` cannot be used with `--fused_backward_pass`. When using Adafactor, the memory usage is slightly larger than with Fused optimizer. PyTorch 2.1 or later is required. - Mechanism: While Fused optimizer performs backward/step for individual parameters within the optimizer, optimizer groups reduce memory usage by grouping parameters and creating multiple optimizers to perform backward/step for each group. Fused optimizer requires implementation on the optimizer side, while optimizer groups are implemented only on the training script side. - LoRA+ is supported. PR [#1233](https://github.com/kohya-ss/sd-scripts/pull/1233) Thanks to rockerBOO! @@ -295,7 +379,7 @@ https://github.com/kohya-ss/sd-scripts/pull/1290) Thanks to frodo821! - SDXL の学習時に Fused optimizer が使えるようになりました。PR [#1259](https://github.com/kohya-ss/sd-scripts/pull/1259) 2kpr 氏に感謝します。 - optimizer の backward pass に step を統合することで学習時のメモリ使用量を大きく削減します。学習結果は未適用時と同一ですが、メモリが潤沢にある場合は速度は遅くなります。 - - `sdxl_train.py` に `--fused_backward_pass` オプションを指定してください。現時点では optimizer は AdaFactor のみ対応しています。また gradient accumulation は使えません。 + - `sdxl_train.py` に `--fused_backward_pass` オプションを指定してください。現時点では optimizer は Adafactor のみ対応しています。また gradient accumulation は使えません。 - mixed precision は `no` のほうが `fp16` や `bf16` よりも使用メモリ量が少ないようです。 - バッチサイズ 1、fp32 で 17GB 程度で学習可能なようです。`--full_bf16` オプションを指定するとさらに削減できます(精度は劣ります)。以前と同じメモリ使用量ではバッチサイズを増やせます。 - PyTorch 2.1 以降の新 API `Tensor.register_post_accumulate_grad_hook(hook)` を使用しているため、PyTorch 2.1 以降が必要です。 @@ -599,11 +683,12 @@ masterpiece, best quality, 1boy, in business suit, standing at street, looking b Lines beginning with `#` are comments. You can specify options for the generated image with options like `--n` after the prompt. The following can be used. - * `--n` Negative prompt up to the next option. + * `--n` Negative prompt up to the next option. Ignored when CFG scale is `1.0`. * `--w` Specifies the width of the generated image. * `--h` Specifies the height of the generated image. * `--d` Specifies the seed of the generated image. - * `--l` Specifies the CFG scale of the generated image. + * `--l` Specifies the CFG scale of the generated image. For FLUX.1 models, the default is `1.0`, which means no CFG. For Chroma models, set to around `4.0` to enable CFG. + * `--g` Specifies the embedded guidance scale for the models with embedded guidance (FLUX.1), the default is `3.5`. Set to `0.0` for Chroma models. * `--s` Specifies the number of steps in the generation. The prompt weighting such as `( )` and `[ ]` are working. diff --git a/_typos.toml b/_typos.toml index bbf7728f4..686da4af2 100644 --- a/_typos.toml +++ b/_typos.toml @@ -29,7 +29,9 @@ koo="koo" yos="yos" wn="wn" hime="hime" - +OT="OT" +byt="byt" +tak="tak" [files] extend-exclude = ["_typos.toml", "venv"] diff --git a/docs/config_README-en.md b/docs/config_README-en.md index 66a50dc09..78687ee6c 100644 --- a/docs/config_README-en.md +++ b/docs/config_README-en.md @@ -1,9 +1,6 @@ -Original Source by kohya-ss +First version: A.I Translation by Model: NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO, editing by Darkstorm2150 -First version: -A.I Translation by Model: NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO, editing by Darkstorm2150 - -Some parts are manually added. +Document is updated and maintained manually. # Config Readme @@ -152,6 +149,7 @@ These options are related to subset configuration. | `keep_tokens_separator` | `“|||”` | o | o | o | | `secondary_separator` | `“;;;”` | o | o | o | | `enable_wildcard` | `true` | o | o | o | +| `resize_interpolation` | (not specified) | o | o | o | * `num_repeats` * Specifies the number of repeats for images in a subset. This is equivalent to `--dataset_repeats` in fine-tuning but can be specified for any training method. @@ -165,6 +163,8 @@ These options are related to subset configuration. * Specifies an additional separator. The part separated by this separator is treated as one tag and is shuffled and dropped. It is then replaced by `caption_separator`. For example, if you specify `aaa;;;bbb;;;ccc`, it will be replaced by `aaa,bbb,ccc` or dropped together. * `enable_wildcard` * Enables wildcard notation. This will be explained later. +* `resize_interpolation` + * Specifies the interpolation method used when resizing images. Normally, there is no need to specify this. The following options can be specified: `lanczos`, `nearest`, `bilinear`, `linear`, `bicubic`, `cubic`, `area`, `box`. By default (when not specified), `area` is used for downscaling, and `lanczos` is used for upscaling. If this option is specified, the same interpolation method will be used for both upscaling and downscaling. When `lanczos` or `box` is specified, PIL is used; for other options, OpenCV is used. ### DreamBooth-specific options @@ -264,10 +264,10 @@ The following command line argument options are ignored if a configuration file * `--reg_data_dir` * `--in_json` -The following command line argument options are given priority over the configuration file options if both are specified simultaneously. In most cases, they have the same names as the corresponding options in the configuration file. +For the command line options listed below, if an option is specified in both the command line arguments and the configuration file, the value from the configuration file will be given priority. Unless otherwise noted, the option names are the same. -| Command Line Argument Option | Prioritized Configuration File Option | -| ------------------------------- | ------------------------------------- | +| Command Line Argument Option | Corresponding Configuration File Option | +| ------------------------------- | --------------------------------------- | | `--bucket_no_upscale` | | | `--bucket_reso_steps` | | | `--caption_dropout_every_n_epochs` | | diff --git a/docs/config_README-ja.md b/docs/config_README-ja.md index 0ed95e0eb..aec0eca5d 100644 --- a/docs/config_README-ja.md +++ b/docs/config_README-ja.md @@ -144,6 +144,7 @@ DreamBooth の手法と fine tuning の手法の両方とも利用可能な学 | `keep_tokens_separator` | `“|||”` | o | o | o | | `secondary_separator` | `“;;;”` | o | o | o | | `enable_wildcard` | `true` | o | o | o | +| `resize_interpolation` |(通常は設定しません) | o | o | o | * `num_repeats` * サブセットの画像の繰り返し回数を指定します。fine tuning における `--dataset_repeats` に相当しますが、`num_repeats` はどの学習方法でも指定可能です。 @@ -162,6 +163,9 @@ DreamBooth の手法と fine tuning の手法の両方とも利用可能な学 * `enable_wildcard` * ワイルドカード記法および複数行キャプションを有効にします。ワイルドカード記法、複数行キャプションについては後述します。 +* `resize_interpolation` + * 画像のリサイズ時に使用する補間方法を指定します。通常は指定しなくて構いません。`lanczos`, `nearest`, `bilinear`, `linear`, `bicubic`, `cubic`, `area`, `box` が指定可能です。デフォルト(未指定時)は、縮小時は `area`、拡大時は `lanczos` になります。このオプションを指定すると、拡大時・縮小時とも同じ補間方法が使用されます。`lanczos`、`box`を指定するとPILが、それ以外を指定するとOpenCVが使用されます。 + ### DreamBooth 方式専用のオプション DreamBooth 方式のオプションは、サブセット向けオプションのみ存在します。 diff --git a/docs/fine_tune.md b/docs/fine_tune.md new file mode 100644 index 000000000..1560fb28a --- /dev/null +++ b/docs/fine_tune.md @@ -0,0 +1,347 @@ +# Fine-tuning Guide + +This document explains how to perform fine-tuning on various model architectures using the `*_train.py` scripts. + +
+日本語 + +# Fine-tuning ガイド + +このドキュメントでは、`*_train.py` スクリプトを用いた、各種モデルアーキテクチャのFine-tuningの方法について解説します。 + +
+ +### Difference between Fine-tuning and LoRA tuning + +This repository supports two methods for additional model training: **Fine-tuning** and **LoRA (Low-Rank Adaptation)**. Each method has distinct features and advantages. + +**Fine-tuning** is a method that retrains all (or most) of the weights of a pre-trained model. +- **Pros**: It can improve the overall expressive power of the model and is suitable for learning styles or concepts that differ significantly from the original model. +- **Cons**: + - It requires a large amount of VRAM and computational cost. + - The saved file size is large (same as the original model). + - It is prone to "overfitting," where the model loses the diversity of the original model if over-trained. +- **Corresponding scripts**: Scripts named `*_train.py`, such as `sdxl_train.py`, `sd3_train.py`, `flux_train.py`, and `lumina_train.py`. + +**LoRA tuning** is a method that freezes the model's weights and only trains a small additional network called an "adapter." +- **Pros**: + - It allows for fast training with low VRAM and computational cost. + - It is considered resistant to overfitting because it trains fewer weights. + - The saved file (LoRA network) is very small, ranging from tens to hundreds of MB, making it easy to manage. + - Multiple LoRAs can be used in combination. +- **Cons**: Since it does not train the entire model, it may not achieve changes as significant as fine-tuning. +- **Corresponding scripts**: Scripts named `*_train_network.py`, such as `sdxl_train_network.py`, `sd3_train_network.py`, and `flux_train_network.py`. + +| Feature | Fine-tuning | LoRA tuning | +|:---|:---|:---| +| **Training Target** | All model weights | Additional network (adapter) only | +| **VRAM/Compute Cost**| High | Low | +| **Training Time** | Long | Short | +| **File Size** | Large (several GB) | Small (few MB to hundreds of MB) | +| **Overfitting Risk** | High | Low | +| **Suitable Use Case** | Major style changes, concept learning | Adding specific characters or styles | + +Generally, it is recommended to start with **LoRA tuning** if you want to add a specific character or style. **Fine-tuning** is a valid option for more fundamental style changes or aiming for a high-quality model. + +
+日本語 + +### Fine-tuningとLoRA学習の違い + +このリポジトリでは、モデルの追加学習手法として**Fine-tuning**と**LoRA (Low-Rank Adaptation)**学習の2種類をサポートしています。それぞれの手法には異なる特徴と利点があります。 + +**Fine-tuning**は、事前学習済みモデルの重み全体(または大部分)を再学習する手法です。 +- **利点**: モデル全体の表現力を向上させることができ、元のモデルから大きく変化した画風やコンセプトの学習に適しています。 +- **欠点**: + - 学習には多くのVRAMと計算コストが必要です。 + - 保存されるファイルサイズが大きくなります(元のモデルと同じサイズ)。 + - 学習させすぎると、元のモデルが持っていた多様性が失われる「過学習(overfitting)」に陥りやすい傾向があります。 +- **対応スクリプト**: `sdxl_train.py`, `sd3_train.py`, `flux_train.py`, `lumina_train.py` など、`*_train.py` という命名規則のスクリプトが対応します。 + +**LoRA学習**は、モデルの重みは凍結(固定)したまま、「アダプター」と呼ばれる小さな追加ネットワークのみを学習する手法です。 +- **利点**: + - 少ないVRAMと計算コストで高速に学習できます。 + - 学習する重みが少ないため、過学習に強いとされています。 + - 保存されるファイル(LoRAネットワーク)は数十〜数百MBと非常に小さく、管理が容易です。 + - 複数のLoRAを組み合わせて使用することも可能です。 +- **欠点**: モデル全体を学習するわけではないため、Fine-tuningほどの大きな変化は期待できない場合があります。 +- **対応スクリプト**: `sdxl_train_network.py`, `sd3_train_network.py`, `flux_train_network.py` など、`*_train_network.py` という命名規則のスクリプトが対応します。 + +| 特徴 | Fine-tuning | LoRA学習 | +|:---|:---|:---| +| **学習対象** | モデルの全重み | 追加ネットワーク(アダプター)のみ | +| **VRAM/計算コスト**| 大 | 小 | +| **学習時間** | 長 | 短 | +| **ファイルサイズ** | 大(数GB) | 小(数MB〜数百MB) | +| **過学習リスク** | 高 | 低 | +| **適した用途** | 大規模な画風変更、コンセプト学習 | 特定のキャラ、画風の追加学習 | + +一般的に、特定のキャラクターや画風を追加したい場合は**LoRA学習**から試すことが推奨されます。より根本的な画風の変更や、高品質なモデルを目指す場合は**Fine-tuning**が有効な選択肢となります。 + +
+ +--- + +### Fine-tuning for each architecture + +Fine-tuning updates the entire weights of the model, so it has different options and considerations than LoRA tuning. This section describes the fine-tuning scripts for major architectures. + +The basic command structure is common to all architectures. + +```bash +accelerate launch --mixed_precision bf16 {script_name}.py \ + --pretrained_model_name_or_path \ + --dataset_config \ + --output_dir \ + --output_name \ + --save_model_as safetensors \ + --max_train_steps 10000 \ + --learning_rate 1e-5 \ + --optimizer_type AdamW8bit +``` + +
+日本語 + +### 各アーキテクチャのFine-tuning + +Fine-tuningはモデルの重み全体を更新するため、LoRA学習とは異なるオプションや考慮事項があります。ここでは主要なアーキテクチャごとのFine-tuningスクリプトについて説明します。 + +基本的なコマンドの構造は、どのアーキテクチャでも共通です。 + +```bash +accelerate launch --mixed_precision bf16 {script_name}.py \ + --pretrained_model_name_or_path \ + --dataset_config \ + --output_dir \ + --output_name \ + --save_model_as safetensors \ + --max_train_steps 10000 \ + --learning_rate 1e-5 \ + --optimizer_type AdamW8bit +``` + +
+ +#### SDXL (`sdxl_train.py`) + +Performs fine-tuning for SDXL models. It is possible to train both the U-Net and the Text Encoders. + +**Key Options:** + +- `--train_text_encoder`: Includes the weights of the Text Encoders (CLIP ViT-L and OpenCLIP ViT-bigG) in the training. Effective for significant style changes or strongly learning specific concepts. +- `--learning_rate_te1`, `--learning_rate_te2`: Set individual learning rates for each Text Encoder. +- `--block_lr`: Divides the U-Net into 23 blocks and sets a different learning rate for each block. This allows for advanced adjustments, such as strengthening or weakening the learning of specific layers. (Not available in LoRA tuning). + +**Command Example:** + +```bash +accelerate launch --mixed_precision bf16 sdxl_train.py \ + --pretrained_model_name_or_path "sd_xl_base_1.0.safetensors" \ + --dataset_config "dataset_config.toml" \ + --output_dir "output" \ + --output_name "sdxl_finetuned" \ + --train_text_encoder \ + --learning_rate 1e-5 \ + --learning_rate_te1 5e-6 \ + --learning_rate_te2 2e-6 +``` + +
+日本語 + +#### SDXL (`sdxl_train.py`) + +SDXLモデルのFine-tuningを行います。U-NetとText Encoderの両方を学習させることが可能です。 + +**主要なオプション:** + +- `--train_text_encoder`: Text Encoder(CLIP ViT-LとOpenCLIP ViT-bigG)の重みを学習対象に含めます。画風を大きく変えたい場合や、特定の概念を強く学習させたい場合に有効です。 +- `--learning_rate_te1`, `--learning_rate_te2`: それぞれのText Encoderに個別の学習率を設定します。 +- `--block_lr`: U-Netを23個のブロックに分割し、ブロックごとに異なる学習率を設定できます。特定の層の学習を強めたり弱めたりする高度な調整が可能です。(LoRA学習では利用できません) + +**コマンド例:** + +```bash +accelerate launch --mixed_precision bf16 sdxl_train.py \ + --pretrained_model_name_or_path "sd_xl_base_1.0.safetensors" \ + --dataset_config "dataset_config.toml" \ + --output_dir "output" \ + --output_name "sdxl_finetuned" \ + --train_text_encoder \ + --learning_rate 1e-5 \ + --learning_rate_te1 5e-6 \ + --learning_rate_te2 2e-6 +``` + +
+ +#### SD3 (`sd3_train.py`) + +Performs fine-tuning for Stable Diffusion 3 Medium models. SD3 consists of three Text Encoders (CLIP-L, CLIP-G, T5-XXL) and a MMDiT (equivalent to U-Net), which can be targeted for training. + +**Key Options:** + +- `--train_text_encoder`: Enables training for CLIP-L and CLIP-G. +- `--train_t5xxl`: Enables training for T5-XXL. T5-XXL is a very large model and requires a lot of VRAM for training. +- `--blocks_to_swap`: A memory optimization feature to reduce VRAM usage. It swaps some blocks of the MMDiT to CPU memory during training. Useful for using larger batch sizes in low VRAM environments. (Also available in LoRA tuning). +- `--num_last_block_to_freeze`: Freezes the weights of the last N blocks of the MMDiT, excluding them from training. Useful for maintaining model stability while focusing on learning in the lower layers. + +**Command Example:** + +```bash +accelerate launch --mixed_precision bf16 sd3_train.py \ + --pretrained_model_name_or_path "sd3_medium.safetensors" \ + --dataset_config "dataset_config.toml" \ + --output_dir "output" \ + --output_name "sd3_finetuned" \ + --train_text_encoder \ + --learning_rate 4e-6 \ + --blocks_to_swap 10 +``` + +
+日本語 + +#### SD3 (`sd3_train.py`) + +Stable Diffusion 3 MediumモデルのFine-tuningを行います。SD3は3つのText Encoder(CLIP-L, CLIP-G, T5-XXL)とMMDiT(U-Netに相当)で構成されており、これらを学習対象にできます。 + +**主要なオプション:** + +- `--train_text_encoder`: CLIP-LとCLIP-Gの学習を有効にします。 +- `--train_t5xxl`: T5-XXLの学習を有効にします。T5-XXLは非常に大きなモデルのため、学習には多くのVRAMが必要です。 +- `--blocks_to_swap`: VRAM使用量を削減するためのメモリ最適化機能です。MMDiTの一部のブロックを学習中にCPUメモリに退避(スワップ)させます。VRAMが少ない環境で大きなバッチサイズを使いたい場合に有効です。(LoRA学習でも利用可能) +- `--num_last_block_to_freeze`: MMDiTの最後のNブロックの重みを凍結し、学習対象から除外します。モデルの安定性を保ちつつ、下位層を中心に学習させたい場合に有効です。 + +**コマンド例:** + +```bash +accelerate launch --mixed_precision bf16 sd3_train.py \ + --pretrained_model_name_or_path "sd3_medium.safetensors" \ + --dataset_config "dataset_config.toml" \ + --output_dir "output" \ + --output_name "sd3_finetuned" \ + --train_text_encoder \ + --learning_rate 4e-6 \ + --blocks_to_swap 10 +``` + +
+ +#### FLUX.1 (`flux_train.py`) + +Performs fine-tuning for FLUX.1 models. FLUX.1 is internally composed of two Transformer blocks (Double Blocks, Single Blocks). + +**Key Options:** + +- `--blocks_to_swap`: Similar to SD3, this feature swaps Transformer blocks to the CPU for memory optimization. +- `--blockwise_fused_optimizers`: An experimental feature that aims to streamline training by applying individual optimizers to each block. + +**Command Example:** + +```bash +accelerate launch --mixed_precision bf16 flux_train.py \ + --pretrained_model_name_or_path "FLUX.1-dev.safetensors" \ + --dataset_config "dataset_config.toml" \ + --output_dir "output" \ + --output_name "flux1_finetuned" \ + --learning_rate 1e-5 \ + --blocks_to_swap 18 +``` + +
+日本語 + +#### FLUX.1 (`flux_train.py`) + +FLUX.1モデルのFine-tuningを行います。FLUX.1は内部的に2つのTransformerブロック(Double Blocks, Single Blocks)で構成されています。 + +**主要なオプション:** + +- `--blocks_to_swap`: SD3と同様に、メモリ最適化のためにTransformerブロックをCPUにスワップする機能です。 +- `--blockwise_fused_optimizers`: 実験的な機能で、各ブロックに個別のオプティマイザを適用し、学習を効率化することを目指します。 + +**コマンド例:** + +```bash +accelerate launch --mixed_precision bf16 flux_train.py \ + --pretrained_model_name_or_path "FLUX.1-dev.safetensors" \ + --dataset_config "dataset_config.toml" \ + --output_dir "output" \ + --output_name "flux1_finetuned" \ + --learning_rate 1e-5 \ + --blocks_to_swap 18 +``` + +
+ +#### Lumina (`lumina_train.py`) + +Performs fine-tuning for Lumina-Next DiT models. + +**Key Options:** + +- `--use_flash_attn`: Enables Flash Attention to speed up computation. +- `lumina_train.py` is relatively new, and many of its options are shared with other scripts. Training can be performed following the basic command pattern. + +**Command Example:** + +```bash +accelerate launch --mixed_precision bf16 lumina_train.py \ + --pretrained_model_name_or_path "Lumina-Next-DiT-B.safetensors" \ + --dataset_config "dataset_config.toml" \ + --output_dir "output" \ + --output_name "lumina_finetuned" \ + --learning_rate 1e-5 +``` + +
+日本語 + +#### Lumina (`lumina_train.py`) + +Lumina-Next DiTモデルのFine-tuningを行います。 + +**主要なオプション:** + +- `--use_flash_attn`: Flash Attentionを有効にし、計算を高速化します。 +- `lumina_train.py`は比較的新しく、オプションは他のスクリプトと共通化されている部分が多いです。基本的なコマンドパターンに従って学習を行えます。 + +**コマンド例:** + +```bash +accelerate launch --mixed_precision bf16 lumina_train.py \ + --pretrained_model_name_or_path "Lumina-Next-DiT-B.safetensors" \ + --dataset_config "dataset_config.toml" \ + --output_dir "output" \ + --output_name "lumina_finetuned" \ + --learning_rate 1e-5 +``` + +
+ +--- + +### Differences between Fine-tuning and LoRA tuning per architecture + +| Architecture | Key Features/Options Specific to Fine-tuning | Main Differences from LoRA tuning | +|:---|:---|:---| +| **SDXL** | `--block_lr` | Only fine-tuning allows for granular control over the learning rate for each U-Net block. | +| **SD3** | `--train_text_encoder`, `--train_t5xxl`, `--num_last_block_to_freeze` | Only fine-tuning can train the entire Text Encoders. LoRA only trains the adapter parts. | +| **FLUX.1** | `--blockwise_fused_optimizers` | Since fine-tuning updates the entire model's weights, more experimental optimizer options are available. | +| **Lumina** | (Few specific options) | Basic training options are common, but fine-tuning differs in that it updates the entire model's foundation. | + +
+日本語 + +### アーキテクチャごとのFine-tuningとLoRA学習の違い + +| アーキテクチャ | Fine-tuning特有の主要機能・オプション | LoRA学習との主な違い | +|:---|:---|:---| +| **SDXL** | `--block_lr` | U-Netのブロックごとに学習率を細かく制御できるのはFine-tuningのみです。 | +| **SD3** | `--train_text_encoder`, `--train_t5xxl`, `--num_last_block_to_freeze` | Text Encoder全体を学習対象にできるのはFine-tuningです。LoRAではアダプター部分のみ学習します。 | +| **FLUX.1** | `--blockwise_fused_optimizers` | Fine-tuningではモデル全体の重みを更新するため、より実験的なオプティマイザの選択肢が用意されています。 | +| **Lumina** | (特有のオプションは少ない) | 基本的な学習オプションは共通ですが、Fine-tuningはモデルの基盤全体を更新する点で異なります。 | + +
diff --git a/docs/flux_train_network.md b/docs/flux_train_network.md new file mode 100644 index 000000000..b8207cb00 --- /dev/null +++ b/docs/flux_train_network.md @@ -0,0 +1,709 @@ +Status: reviewed + +# LoRA Training Guide for FLUX.1 using `flux_train_network.py` / `flux_train_network.py` を用いたFLUX.1モデルのLoRA学習ガイド + +This document explains how to train LoRA models for the FLUX.1 model using `flux_train_network.py` included in the `sd-scripts` repository. + +
+日本語 + +このドキュメントでは、`sd-scripts`リポジトリに含まれる`flux_train_network.py`を使用して、FLUX.1モデルに対するLoRA (Low-Rank Adaptation) モデルを学習する基本的な手順について解説します。 + +
+ +## 1. Introduction / はじめに + +`flux_train_network.py` trains additional networks such as LoRA on the FLUX.1 model, which uses a transformer-based architecture different from Stable Diffusion. Two text encoders, CLIP-L and T5-XXL, and a dedicated AutoEncoder are used. + +This guide assumes you know the basics of LoRA training. For common options see [train_network.py](train_network.md) and [sdxl_train_network.py](sdxl_train_network.md). + +**Prerequisites:** + +* The repository is cloned and the Python environment is ready. +* A training dataset is prepared. See the dataset configuration guide. + +
+日本語 + +`flux_train_network.py`は、FLUX.1モデルに対してLoRAなどの追加ネットワークを学習させるためのスクリプトです。FLUX.1はStable Diffusionとは異なるアーキテクチャを持つ画像生成モデルであり、このスクリプトを使用することで、特定のキャラクターや画風を再現するLoRAモデルを作成できます。 + +このガイドは、基本的なLoRA学習の手順を理解しているユーザーを対象としています。基本的な使い方や共通のオプションについては、[`train_network.py`のガイド](train_network.md)を参照してください。また一部のパラメータは [`sdxl_train_network.py`](sdxl_train_network.md) と同様のものがあるため、そちらも参考にしてください。 + +**前提条件:** + +* `sd-scripts`リポジトリのクローンとPython環境のセットアップが完了していること。 +* 学習用データセットの準備が完了していること。(データセットの準備については[データセット設定ガイド](link/to/dataset/config/doc)を参照してください) + +
+ +## 2. Differences from `train_network.py` / `train_network.py` との違い + +`flux_train_network.py` is based on `train_network.py` but adapted for FLUX.1. Main differences include: + +* **Target model:** FLUX.1 model (dev or schnell version). +* **Model structure:** Unlike Stable Diffusion, FLUX.1 uses a Transformer-based architecture with two text encoders (CLIP-L and T5-XXL) and a dedicated AutoEncoder (AE) instead of VAE. +* **Required arguments:** Additional arguments for FLUX.1 model, CLIP-L, T5-XXL, and AE model files. +* **Incompatible options:** Some Stable Diffusion-specific arguments (e.g., `--v2`, `--clip_skip`, `--max_token_length`) are not used in FLUX.1 training. +* **FLUX.1-specific arguments:** Additional arguments for FLUX.1-specific training parameters like timestep sampling and guidance scale. + +
+日本語 + +`flux_train_network.py`は`train_network.py`をベースに、FLUX.1モデルに対応するための変更が加えられています。主な違いは以下の通りです。 + +* **対象モデル:** FLUX.1モデル(dev版またはschnell版)を対象とします。 +* **モデル構造:** Stable Diffusionとは異なり、FLUX.1はTransformerベースのアーキテクチャを持ちます。Text EncoderとしてCLIP-LとT5-XXLの二つを使用し、VAEの代わりに専用のAutoEncoder (AE) を使用します。 +* **必須の引数:** FLUX.1モデル、CLIP-L、T5-XXL、AEの各モデルファイルを指定する引数が追加されています。 +* **一部引数の非互換性:** Stable Diffusion向けの引数の一部(例: `--v2`, `--clip_skip`, `--max_token_length`)はFLUX.1の学習では使用されません。 +* **FLUX.1特有の引数:** タイムステップのサンプリング方法やガイダンススケールなど、FLUX.1特有の学習パラメータを指定する引数が追加されています。 + +
+ +## 3. Preparation / 準備 + +Before starting training you need: + +1. **Training script:** `flux_train_network.py` +2. **FLUX.1 model file:** Base FLUX.1 model `.safetensors` file (e.g., `flux1-dev.safetensors`). +3. **Text Encoder model files:** + - CLIP-L model `.safetensors` file (e.g., `clip_l.safetensors`) + - T5-XXL model `.safetensors` file (e.g., `t5xxl.safetensors`) +4. **AutoEncoder model file:** FLUX.1-compatible AE model `.safetensors` file (e.g., `ae.safetensors`). +5. **Dataset definition file (.toml):** TOML format file describing training dataset configuration (e.g., `my_flux_dataset_config.toml`). + +### Downloading Required Models + +To train FLUX.1 models, you need to download the following model files: + +- **DiT, AE**: Download from the [black-forest-labs/FLUX.1 dev](https://huggingface.co/black-forest-labs/FLUX.1-dev) repository. Use `flux1-dev.safetensors` and `ae.safetensors`. The weights in the subfolder are in Diffusers format and cannot be used. +- **Text Encoder 1 (T5-XXL), Text Encoder 2 (CLIP-L)**: Download from the [ComfyUI FLUX Text Encoders](https://huggingface.co/comfyanonymous/flux_text_encoders) repository. Please use `t5xxl_fp16.safetensors` for T5-XXL. Thanks to ComfyUI for providing these models. + +To train Chroma models, you need to download the Chroma model file from the following repository: + +- **Chroma Base**: Download from the [lodestones/Chroma1-Base](https://huggingface.co/lodestones/Chroma1-Base) repository. Use `Chroma.safetensors`. + +We have tested Chroma training with the weights from the [lodestones/Chroma](https://huggingface.co/lodestones/Chroma) repository. + +AE and T5-XXL models are same as FLUX.1, so you can use the same files. CLIP-L model is not used for Chroma training, so you can omit the `--clip_l` argument. + +
+日本語 + +学習を開始する前に、以下のファイルが必要です。 + +1. **学習スクリプト:** `flux_train_network.py` +2. **FLUX.1モデルファイル:** 学習のベースとなるFLUX.1モデルの`.safetensors`ファイル(例: `flux1-dev.safetensors`)。 +3. **Text Encoderモデルファイル:** + - CLIP-Lモデルの`.safetensors`ファイル。例として`clip_l.safetensors`を使用します。 + - T5-XXLモデルの`.safetensors`ファイル。例として`t5xxl.safetensors`を使用します。 +4. **AutoEncoderモデルファイル:** FLUX.1に対応するAEモデルの`.safetensors`ファイル。例として`ae.safetensors`を使用します。 +5. **データセット定義ファイル (.toml):** 学習データセットの設定を記述したTOML形式のファイル。(詳細は[データセット設定ガイド](link/to/dataset/config/doc)を参照してください)。例として`my_flux_dataset_config.toml`を使用します。 + +**必要なモデルのダウンロード** + +FLUX.1モデルを学習するためには、以下のモデルファイルをダウンロードする必要があります。 + +- **DiT, AE**: [black-forest-labs/FLUX.1 dev](https://huggingface.co/black-forest-labs/FLUX.1-dev) リポジトリからダウンロードします。`flux1-dev.safetensors`と`ae.safetensors`を使用してください。サブフォルダ内の重みはDiffusers形式であり、使用できません。 +- **Text Encoder 1 (T5-XXL), Text Encoder 2 (CLIP-L)**: [ComfyUI FLUX Text Encoders](https://huggingface.co/comfyanonymous/flux_text_encoders) リポジトリからダウンロードします。T5-XXLには`t5xxl_fp16.safetensors`を使用してください。これらのモデルを提供いただいたComfyUIに感謝します。 + +Chromaモデルを学習する場合は、以下のリポジトリからChromaモデルファイルをダウンロードする必要があります。 + +- **Chroma Base**: [lodestones/Chroma1-Base](https://huggingface.co/lodestones/Chroma1-Base) リポジトリからダウンロードします。`Chroma.safetensors`を使用してください。 + +Chromaの学習のテストは [lodestones/Chroma](https://huggingface.co/lodestones/Chroma) リポジトリの重みを使用して行いました。 + +AEとT5-XXLモデルはFLUX.1と同じものを使用できるため、同じファイルを使用します。CLIP-LモデルはChroma学習では使用されないため、`--clip_l`引数は省略できます。 + +
+ +## 4. Running the Training / 学習の実行 + +Run `flux_train_network.py` from the terminal with FLUX.1 specific arguments. Here's a basic command example: + +```bash +accelerate launch --num_cpu_threads_per_process 1 flux_train_network.py \ + --pretrained_model_name_or_path="" \ + --clip_l="" \ + --t5xxl="" \ + --ae="" \ + --dataset_config="my_flux_dataset_config.toml" \ + --output_dir="" \ + --output_name="my_flux_lora" \ + --save_model_as=safetensors \ + --network_module=networks.lora_flux \ + --network_dim=16 \ + --network_alpha=1 \ + --learning_rate=1e-4 \ + --optimizer_type="AdamW8bit" \ + --lr_scheduler="constant" \ + --sdpa \ + --max_train_epochs=10 \ + --save_every_n_epochs=1 \ + --mixed_precision="fp16" \ + --gradient_checkpointing \ + --guidance_scale=1.0 \ + --timestep_sampling="flux_shift" \ + --model_prediction_type="raw" \ + --blocks_to_swap=18 \ + --cache_text_encoder_outputs \ + --cache_latents +``` + +### Training Chroma Models + +If you want to train a Chroma model, specify `--model_type=chroma`. Chroma does not use CLIP-L, so the `--clip_l` argument is not needed. T5XXL and AE are same as FLUX.1. The command would look like this: + +```bash +accelerate launch --num_cpu_threads_per_process 1 flux_train_network.py \ + --pretrained_model_name_or_path="" \ + --model_type=chroma \ + --t5xxl="" \ + --ae="" \ + --dataset_config="my_flux_dataset_config.toml" \ + --output_dir="" \ + --output_name="my_chroma_lora" \ + --guidance_scale=0.0 \ + --timestep_sampling="sigmoid" \ + --apply_t5_attn_mask \ + ... +``` + +Note that for Chroma models, `--guidance_scale=0.0` is required to disable guidance scale, and `--apply_t5_attn_mask` is needed to apply attention masks for T5XXL Text Encoder. + +The sample image generation during training requires specifying a negative prompt. Also, set `--g 0` to disable embedded guidance scale and `--l 4.0` to set the CFG scale. For example: + +``` +Japanese shrine in the summer forest. --n low quality, ugly, unfinished, out of focus, deformed, disfigure, blurry, smudged, restricted palette, flat colors --w 512 --h 512 --d 1 --l 4.0 --g 0.0 --s 20 +``` + +
+日本語 + +学習は、ターミナルから`flux_train_network.py`を実行することで開始します。基本的なコマンドラインの構造は`train_network.py`と同様ですが、FLUX.1特有の引数を指定する必要があります。 + +コマンドラインの例は英語のドキュメントを参照してください。 + +#### Chromaモデルの学習 + +Chromaモデルを学習したい場合は、`--model_type=chroma`を指定します。ChromaはCLIP-Lを使用しないため、`--clip_l`引数は不要です。T5XXLとAEはFLUX.1と同様です。 + +コマンドラインの例は英語のドキュメントを参照してください。 + +学習中のサンプル画像生成には、ネガティブプロンプトを指定してください。また `--g 0` を指定して埋め込みガイダンススケールを無効化し、`--l 4.0` を指定してCFGスケールを設定します。 + +
+ +### 4.1. Explanation of Key Options / 主要なコマンドライン引数の解説 + +The script adds FLUX.1 specific arguments. For common arguments (like `--output_dir`, `--output_name`, `--network_module`, etc.), see the [`train_network.py` guide](train_network.md). + +#### Model-related [Required] + +* `--pretrained_model_name_or_path=""` **[Required]** + - Specifies the path to the base FLUX.1 or Chroma model `.safetensors` file. Diffusers format directories are not currently supported. +* `--model_type=` + - Specifies the type of base model for training. Choose from `flux` or `chroma`. Default is `flux`. +* `--clip_l=""` **[Required when flux is selected]** + - Specifies the path to the CLIP-L Text Encoder model `.safetensors` file. Not needed when `--model_type=chroma`. +* `--t5xxl=""` **[Required]** + - Specifies the path to the T5-XXL Text Encoder model `.safetensors` file. +* `--ae=""` **[Required]** + - Specifies the path to the FLUX.1-compatible AutoEncoder model `.safetensors` file. + +#### FLUX.1 Training Parameters + +* `--guidance_scale=` + - FLUX.1 dev version is distilled with specific guidance scale values, but for training, specify `1.0` to disable guidance scale. Default is `3.5`, so be sure to specify this. Usually ignored for schnell version. + - Chroma requires `--guidance_scale=0.0` to disable guidance scale. +* `--timestep_sampling=` + - Specifies the sampling method for timesteps (noise levels) during training. Choose from `sigma`, `uniform`, `sigmoid`, `shift`, `flux_shift`. Default is `sigma`. Recommended is `flux_shift`. For Chroma models, `sigmoid` is recommended. +* `--sigmoid_scale=` + - Scale factor when `timestep_sampling` is set to `sigmoid`, `shift`, or `flux_shift`. Default and recommended value is `1.0`. +* `--model_prediction_type=` + - Specifies what the model predicts. Choose from `raw` (use prediction as-is), `additive` (add to noise input), `sigma_scaled` (apply sigma scaling). Default is `sigma_scaled`. Recommended is `raw`. +* `--discrete_flow_shift=` + - Specifies the shift value for the scheduler used in Flow Matching. Default is `3.0`. This value is ignored when `timestep_sampling` is set to other than `shift`. + +#### Memory/Speed Related + +* `--fp8_base` + - Enables training in FP8 format for FLUX.1, CLIP-L, and T5-XXL. This can significantly reduce VRAM usage, but the training results may vary. +* `--blocks_to_swap=` **[Experimental Feature]** + - Setting to reduce VRAM usage by swapping parts of the model (Transformer blocks) between CPU and GPU. Specify the number of blocks to swap as an integer (e.g., `18`). Larger values reduce VRAM usage but decrease training speed. Adjust according to your GPU's VRAM capacity. Can be used with `gradient_checkpointing`. + - Cannot be used with `--cpu_offload_checkpointing`. +* `--cache_text_encoder_outputs` + - Caches the outputs of CLIP-L and T5-XXL. This reduces memory usage. +* `--cache_latents`, `--cache_latents_to_disk` + - Caches the outputs of AE. Similar functionality to [sdxl_train_network.py](sdxl_train_network.md). + +#### Incompatible/Deprecated Arguments + +* `--v2`, `--v_parameterization`, `--clip_skip`: These are Stable Diffusion-specific arguments and are not used in FLUX.1 training. +* `--max_token_length`: This is an argument for Stable Diffusion v1/v2. For FLUX.1, use `--t5xxl_max_token_length`. +* `--split_mode`: Deprecated argument. Use `--blocks_to_swap` instead. + +
+日本語 + +[`train_network.py`のガイド](train_network.md)で説明されている引数に加え、以下のFLUX.1特有の引数を指定します。共通の引数(`--output_dir`, `--output_name`, `--network_module`, `--network_dim`, `--network_alpha`, `--learning_rate`など)については、上記ガイドを参照してください。 + +コマンドラインの例と詳細な引数の説明は英語のドキュメントを参照してください。 + +
+ +### 4.2. Starting Training / 学習の開始 + +Training begins once you run the command with the required options. Log checking is the same as in [`train_network.py`](train_network.md#32-starting-the-training--学習の開始). + +
+日本語 + +必要な引数を設定し、コマンドを実行すると学習が開始されます。基本的な流れやログの確認方法は[`train_network.py`のガイド](train_network.md#32-starting-the-training--学習の開始)と同様です。 + +
+ +## 5. Using the Trained Model / 学習済みモデルの利用 + +After training, a LoRA model file is saved in `output_dir` and can be used in inference environments supporting FLUX.1 (e.g. ComfyUI + Flux nodes). + +
+日本語 + +学習が完了すると、指定した`output_dir`にLoRAモデルファイル(例: `my_flux_lora.safetensors`)が保存されます。このファイルは、FLUX.1モデルに対応した推論環境(例: ComfyUI + ComfyUI-FluxNodes)で使用できます。 + +
+ +## 6. Advanced Settings / 高度な設定 + +### 6.1. VRAM Usage Optimization / VRAM使用量の最適化 + +FLUX.1 is a relatively large model, so GPUs without sufficient VRAM require optimization. Here are settings to reduce VRAM usage (with `--fp8_base`): + +#### Recommended Settings by GPU Memory + +| GPU Memory | Recommended Settings | +|------------|---------------------| +| 24GB VRAM | Basic settings work fine (batch size 2) | +| 16GB VRAM | Set batch size to 1 and use `--blocks_to_swap` | +| 12GB VRAM | Use `--blocks_to_swap 16` and 8bit AdamW | +| 10GB VRAM | Use `--blocks_to_swap 22`, recommend fp8 format for T5XXL | +| 8GB VRAM | Use `--blocks_to_swap 28`, recommend fp8 format for T5XXL | + +#### Key VRAM Reduction Options + +- **`--fp8_base`**: Enables training in FP8 format. + +- **`--blocks_to_swap `**: Swaps blocks between CPU and GPU to reduce VRAM usage. Higher numbers save more VRAM but reduce training speed. FLUX.1 supports up to 35 blocks for swapping. + +- **`--cpu_offload_checkpointing`**: Offloads gradient checkpoints to CPU. Can reduce VRAM usage by up to 1GB but decreases training speed by about 15%. Cannot be used with `--blocks_to_swap`. Chroma models do not support this option. + +- **Using Adafactor optimizer**: Can reduce VRAM usage more than 8bit AdamW: + ``` + --optimizer_type adafactor --optimizer_args "relative_step=False" "scale_parameter=False" "warmup_init=False" --lr_scheduler constant_with_warmup --max_grad_norm 0.0 + ``` + +- **Using T5XXL fp8 format**: For GPUs with less than 10GB VRAM, using fp8 format T5XXL checkpoints is recommended. Download `t5xxl_fp8_e4m3fn.safetensors` from [comfyanonymous/flux_text_encoders](https://huggingface.co/comfyanonymous/flux_text_encoders) (use without `scaled`). + +- **FP8/FP16 Mixed Training [Experimental]**: Specify `--fp8_base_unet` to train the FLUX.1 model in FP8 format while training Text Encoders (CLIP-L/T5XXL) in BF16/FP16 format. This can further reduce VRAM usage. + +
+日本語 + +FLUX.1モデルは比較的大きなモデルであるため、十分なVRAMを持たないGPUでは工夫が必要です。VRAM使用量を削減するための設定の詳細は英語のドキュメントを参照してください。 + +主要なVRAM削減オプション: +- `--fp8_base`: FP8形式での学習を有効化 +- `--blocks_to_swap`: CPUとGPU間でブロックをスワップ +- `--cpu_offload_checkpointing`: 勾配チェックポイントをCPUにオフロード +- Adafactorオプティマイザの使用 +- T5XXLのfp8形式の使用 +- FP8/FP16混合学習(実験的機能) + +
+ +### 6.2. Important FLUX.1 LoRA Training Settings / FLUX.1 LoRA学習の重要な設定 + +FLUX.1 training has many unknowns, and several settings can be specified with arguments: + +#### Timestep Sampling Methods + +The `--timestep_sampling` option specifies how timesteps (0-1) are sampled: + +- `sigma`: Sigma-based like SD3 +- `uniform`: Uniform random +- `sigmoid`: Sigmoid of normal distribution random (similar to x-flux, AI-toolkit) +- `shift`: Sigmoid value of normal distribution random with shift. The `--discrete_flow_shift` setting is used to shift the sigmoid value. +- `flux_shift`: Shift sigmoid value of normal distribution random according to resolution (similar to FLUX.1 dev inference). + +`--discrete_flow_shift` only applies when `--timestep_sampling` is set to `shift`. + +#### Model Prediction Processing + +The `--model_prediction_type` option specifies how to interpret and process model predictions: + +- `raw`: Use as-is (similar to x-flux) **[Recommended]** +- `additive`: Add to noise input +- `sigma_scaled`: Apply sigma scaling (similar to SD3) + +#### Recommended Settings + +Based on experiments, the following settings work well: +``` +--timestep_sampling shift --discrete_flow_shift 3.1582 --model_prediction_type raw --guidance_scale 1.0 +``` + +For Chroma models, the following settings are recommended: +``` +--timestep_sampling sigmoid --model_prediction_type raw --guidance_scale 0.0 +``` + +**About Guidance Scale**: FLUX.1 dev version is distilled with specific guidance scale values, but for training, specify `--guidance_scale 1.0` to disable guidance scale. Chroma requires `--guidance_scale 0.0` to disable guidance scale because it is not distilled. + +
+日本語 + +FLUX.1の学習には多くの未知の点があり、いくつかの設定は引数で指定できます。詳細な説明とコマンドラインの例は英語のドキュメントを参照してください。 + +主要な設定オプション: +- タイムステップのサンプリング方法(`--timestep_sampling`) +- モデル予測の処理方法(`--model_prediction_type`) +- 推奨設定の組み合わせ + +
+ +### 6.3. Layer-specific Rank Configuration / 各層に対するランク指定 + +You can specify different ranks (network_dim) for each layer of FLUX.1. This allows you to emphasize or disable LoRA effects for specific layers. + +Specify the following network_args to set ranks for each layer. Setting 0 disables LoRA for that layer: + +| network_args | Target Layer | +|--------------|--------------| +| img_attn_dim | DoubleStreamBlock img_attn | +| txt_attn_dim | DoubleStreamBlock txt_attn | +| img_mlp_dim | DoubleStreamBlock img_mlp | +| txt_mlp_dim | DoubleStreamBlock txt_mlp | +| img_mod_dim | DoubleStreamBlock img_mod | +| txt_mod_dim | DoubleStreamBlock txt_mod | +| single_dim | SingleStreamBlock linear1 and linear2 | +| single_mod_dim | SingleStreamBlock modulation | + +Example usage: +``` +--network_args "img_attn_dim=4" "img_mlp_dim=8" "txt_attn_dim=2" "txt_mlp_dim=2" "img_mod_dim=2" "txt_mod_dim=2" "single_dim=4" "single_mod_dim=2" +``` + +To apply LoRA to FLUX conditioning layers, specify `in_dims` in network_args as a comma-separated list of 5 numbers: + +``` +--network_args "in_dims=[4,2,2,2,4]" +``` + +Each number corresponds to `img_in`, `time_in`, `vector_in`, `guidance_in`, `txt_in`. The example above applies LoRA to all conditioning layers with ranks of 4 for `img_in` and `txt_in`, and ranks of 2 for others. + +
+日本語 + +FLUX.1の各層に対して異なるランク(network_dim)を指定できます。これにより、特定の層に対してLoRAの効果を強調したり、無効化したりできます。 + +詳細な設定方法とコマンドラインの例は英語のドキュメントを参照してください。 + +
+ +### 6.4. Block Selection for Training / 学習するブロックの指定 + +You can specify which blocks to train using `train_double_block_indices` and `train_single_block_indices` in network_args. Indices are 0-based. Default is to train all blocks if omitted. + +Specify indices as integer lists like `0,1,5,8` or integer ranges like `0,1,4-5,7`: +- Double blocks: 19 blocks, valid range 0-18 +- Single blocks: 38 blocks, valid range 0-37 +- Specify `all` to train all blocks +- Specify `none` to skip training blocks + +Example usage: +``` +--network_args "train_double_block_indices=0,1,8-12,18" "train_single_block_indices=3,10,20-25,37" +``` + +Or: +``` +--network_args "train_double_block_indices=none" "train_single_block_indices=10-15" +``` + +
+日本語 + +FLUX.1 LoRA学習では、network_argsの`train_double_block_indices`と`train_single_block_indices`を指定することで、学習するブロックを指定できます。 + +詳細な設定方法とコマンドラインの例は英語のドキュメントを参照してください。 + +
+ +### 6.5. Regular Expression-based Rank/LR Configuration / 正規表現によるランク・学習率の指定 + +You can specify ranks (dims) and learning rates for LoRA modules using regular expressions. This allows for more flexible and fine-grained control than specifying by layer. + +These settings are specified via the `network_args` argument. + +* `network_reg_dims`: Specify ranks for modules matching a regular expression. The format is a comma-separated string of `pattern=rank`. + * Example: `--network_args "network_reg_dims=single.*_modulation.*=4,img_attn=8"` + * This sets the rank to 4 for modules whose names contain `single` and contain `_modulation`, and to 8 for modules containing `img_attn`. +* `network_reg_lrs`: Specify learning rates for modules matching a regular expression. The format is a comma-separated string of `pattern=lr`. + * Example: `--network_args "network_reg_lrs=single_blocks_(\d|10)_=1e-3,double_blocks=2e-3"` + * This sets the learning rate to `1e-3` for modules whose names contain `single_blocks` followed by a digit (`0` to `9`) or `10`, and to `2e-3` for modules whose names contain `double_blocks`. + +**Notes:** + +* Settings via `network_reg_dims` and `network_reg_lrs` take precedence over the global `--network_dim` and `--learning_rate` settings. +* If a module name matches multiple patterns, the setting from the last matching pattern in the string will be applied. +* These settings are applied after the block-specific training settings (`train_double_block_indices`, `train_single_block_indices`). + +
+日本語 + +正規表現を用いて、LoRAのモジュールごとにランク(dim)や学習率を指定することができます。これにより、層ごとの指定よりも柔軟できめ細やかな制御が可能になります。 + +これらの設定は `network_args` 引数で指定します。 + +* `network_reg_dims`: 正規表現にマッチするモジュールに対してランクを指定します。`pattern=rank` という形式の文字列をカンマで区切って指定します。 + * 例: `--network_args "network_reg_dims=single.*_modulation.*=4,img_attn=8"` + * この例では、名前に `single` で始まり `_modulation` を含むモジュールのランクを4に、`img_attn` を含むモジュールのランクを8に設定します。 +* `network_reg_lrs`: 正規表現にマッチするモジュールに対して学習率を指定します。`pattern=lr` という形式の文字列をカンマで区切って指定します。 + * 例: `--network_args "network_reg_lrs=single_blocks_(\d|10)_=1e-3,double_blocks=2e-3"` + * この例では、名前が `single_blocks` で始まり、後に数字(`0`から`9`)または`10`が続くモジュールの学習率を `1e-3` に、`double_blocks` を含むモジュールの学習率を `2e-3` に設定します。 +**注意点:** + +* `network_reg_dims` および `network_reg_lrs` での設定は、全体設定である `--network_dim` や `--learning_rate` よりも優先されます。 +* あるモジュール名が複数のパターンにマッチした場合、文字列の中で後方にあるパターンの設定が適用されます。 +* これらの設定は、ブロック指定(`train_double_block_indices`, `train_single_block_indices`)が適用された後に行われます。 + +
+ +### 6.6. Text Encoder LoRA Support / Text Encoder LoRAのサポート + +FLUX.1 LoRA training supports training CLIP-L and T5XXL LoRA: + +- To train only FLUX.1: specify `--network_train_unet_only` +- To train FLUX.1 and CLIP-L: omit `--network_train_unet_only` +- To train FLUX.1, CLIP-L, and T5XXL: omit `--network_train_unet_only` and add `--network_args "train_t5xxl=True"` + +You can specify individual learning rates for CLIP-L and T5XXL with `--text_encoder_lr`. For example, `--text_encoder_lr 1e-4 1e-5` sets the first value for CLIP-L and the second for T5XXL. Specifying one value uses the same learning rate for both. If `--text_encoder_lr` is not specified, the default `--learning_rate` is used for both. + +
+日本語 + +FLUX.1 LoRA学習は、CLIP-LとT5XXL LoRAのトレーニングもサポートしています。 + +詳細な設定方法とコマンドラインの例は英語のドキュメントを参照してください。 + +
+ +### 6.7. Multi-Resolution Training / マルチ解像度トレーニング + +You can define multiple resolutions in the dataset configuration file, with different batch sizes for each resolution. + +Configuration file example: +```toml +[general] +# Common settings +flip_aug = true +color_aug = false +keep_tokens_separator= "|||" +shuffle_caption = false +caption_tag_dropout_rate = 0 +caption_extension = ".txt" + +[[datasets]] +# First resolution settings +batch_size = 2 +enable_bucket = true +resolution = [1024, 1024] + + [[datasets.subsets]] + image_dir = "path/to/image/directory" + num_repeats = 1 + +[[datasets]] +# Second resolution settings +batch_size = 3 +enable_bucket = true +resolution = [768, 768] + + [[datasets.subsets]] + image_dir = "path/to/image/directory" + num_repeats = 1 +``` + +
+日本語 + +データセット設定ファイルで複数の解像度を定義できます。各解像度に対して異なるバッチサイズを指定することができます。 + +設定ファイルの例は英語のドキュメントを参照してください。 + +
+ +### 6.8. Validation / 検証 + +You can calculate validation loss during training using a validation dataset to evaluate model generalization performance. + +To set up validation, add a `validation_split` and optionally `validation_seed` to your dataset configuration TOML file. + +```toml +validation_seed = 42 # [Optional] Validation seed, otherwise uses training seed for validation split . +enable_bucket = true +resolution = [1024, 1024] + +[[datasets]] + [[datasets.subsets]] + # This directory will use 100% of the images for training + image_dir = "path/to/image/directory" + +[[datasets]] +validation_split = 0.1 # Split between 0.0 and 1.0 where 1.0 will use the full subset as a validation dataset + + [[datasets.subsets]] + # This directory will split 10% to validation and 90% to training + image_dir = "path/to/image/second-directory" + +[[datasets]] +validation_split = 1.0 # Will use this full subset as a validation subset. + + [[datasets.subsets]] + # This directory will use the 100% to validation and 0% to training + image_dir = "path/to/image/full_validation" +``` + +**Notes:** + +* Validation loss calculation uses fixed timestep sampling and random seeds to reduce loss variation due to randomness for more stable evaluation. +* Currently, validation loss is not supported when using Schedule-Free optimizers (`AdamWScheduleFree`, `RAdamScheduleFree`, `ProdigyScheduleFree`). + +
+日本語 + +学習中に検証データセットを使用して損失 (Validation Loss) を計算し、モデルの汎化性能を評価できます。 + +詳細な設定方法とコマンドラインの例は英語のドキュメントを参照してください。 + +
+ +## 7. Additional Options / 追加オプション + +### 7.1. Other FLUX.1-specific Options / その他のFLUX.1特有のオプション + +- **T5 Attention Mask Application**: Specify `--apply_t5_attn_mask` to apply attention masks during T5XXL Text Encoder training and inference. Not recommended due to limited inference environment support. **For Chroma models, this option is required.** + +- **IP Noise Gamma**: Use `--ip_noise_gamma` and `--ip_noise_gamma_random_strength` to adjust Input Perturbation noise gamma values during training. See Stable Diffusion 3 training options for details. + +- **LoRA-GGPO Support**: Use LoRA-GGPO (Gradient Group Proportion Optimizer) to stabilize LoRA training: + ```bash + --network_args "ggpo_sigma=0.03" "ggpo_beta=0.01" + ``` + +- **Q/K/V Projection Layer Splitting [Experimental]**: Specify `--network_args "split_qkv=True"` to individually split and apply LoRA to Q/K/V (and SingleStreamBlock Text) projection layers within Attention layers. + +
+日本語 + +その他のFLUX.1特有のオプション: +- T5 Attention Maskの適用(Chromaモデルでは必須) +- IPノイズガンマ +- LoRA-GGPOサポート +- Q/K/V射影層の分割(実験的機能) + +詳細な設定方法とコマンドラインの例は英語のドキュメントを参照してください。 + +
+ +### 7.2. Dataset-related Additional Options / データセット関連の追加オプション + +#### Interpolation Method for Resizing + +You can specify the interpolation method when resizing dataset images to training resolution. Specify `interpolation_type` in the `[[datasets]]` or `[general]` section of the dataset configuration TOML file. + +Available values: `bicubic` (default), `bilinear`, `lanczos`, `nearest`, `area` + +```toml +[[datasets]] +resolution = [1024, 1024] +enable_bucket = true +interpolation_type = "lanczos" # Example: Use Lanczos interpolation +# ... +``` + +
+日本語 + +データセットの画像を学習解像度にリサイズする際の補間方法を指定できます。 + +設定方法とオプションの詳細は英語のドキュメントを参照してください。 + +
+ +### 7.3. Other Training Options / その他の学習オプション + +- **`--controlnet_model_name_or_path`**: Specifies the path to a ControlNet model compatible with FLUX.1. This allows for training a LoRA that works in conjunction with ControlNet. This is an advanced feature and requires a compatible ControlNet model. + +- **`--loss_type`**: Specifies the loss function for training. The default is `l2`. + - `l1`: L1 loss. + - `l2`: L2 loss (mean squared error). + - `huber`: Huber loss. + - `smooth_l1`: Smooth L1 loss. + +- **`--huber_schedule`**, **`--huber_c`**, **`--huber_scale`**: These are parameters for Huber loss. They are used when `--loss_type` is set to `huber` or `smooth_l1`. + +- **`--t5xxl_max_token_length`**: Specifies the maximum token length for the T5-XXL text encoder. For details, refer to the [`sd3_train_network.md` guide](sd3_train_network.md). + +- **`--weighting_scheme`**, **`--logit_mean`**, **`--logit_std`**, **`--mode_scale`**: These options allow you to adjust the loss weighting for each timestep. For details, refer to the [`sd3_train_network.md` guide](sd3_train_network.md). + +- **`--fused_backward_pass`**: Fuses the backward pass and optimizer step to reduce VRAM usage. For details, refer to the [`sdxl_train_network.md` guide](sdxl_train_network.md). + +
+日本語 + +- **`--controlnet_model_name_or_path`**: FLUX.1互換のControlNetモデルへのパスを指定します。これにより、ControlNetと連携して動作するLoRAを学習できます。これは高度な機能であり、互換性のあるControlNetモデルが必要です。 +- **`--loss_type`**: 学習に用いる損失関数を指定します。デフォルトは `l2` です。 + - `l1`: L1損失。 + - `l2`: L2損失(平均二乗誤差)。 + - `huber`: Huber損失。 + - `smooth_l1`: Smooth L1損失。 +- **`--huber_schedule`**, **`--huber_c`**, **`--huber_scale`**: これらはHuber損失のパラメータです。`--loss_type` が `huber` または `smooth_l1` の場合に使用されます。 +- **`--t5xxl_max_token_length`**: T5-XXLテキストエンコーダの最大トークン長を指定します。詳細は [`sd3_train_network.md` ガイド](sd3_train_network.md) を参照してください。 +- **`--weighting_scheme`**, **`--logit_mean`**, **`--logit_std`**, **`--mode_scale`**: これらのオプションは、各タイムステップの損失の重み付けを調整するために使用されます。詳細は [`sd3_train_network.md` ガイド](sd3_train_network.md) を参照してください。 +- **`--fused_backward_pass`**: バックワードパスとオプティマイザステップを融合してVRAM使用量を削減します。詳細は [`sdxl_train_network.md` ガイド](sdxl_train_network.md) を参照してください。 + +
+ +## 8. Related Tools / 関連ツール + +Several related scripts are provided for models trained with `flux_train_network.py` and to assist with the training process: + +* **`networks/flux_extract_lora.py`**: Extracts LoRA models from the difference between trained and base models. +* **`convert_flux_lora.py`**: Converts trained LoRA models to other formats like Diffusers (AI-Toolkit) format. When trained with Q/K/V split option, converting with this script can reduce model size. +* **`networks/flux_merge_lora.py`**: Merges trained LoRA models into FLUX.1 base models. +* **`flux_minimal_inference.py`**: Simple inference script for generating images with trained LoRA models. You can specify `flux` or `chroma` with the `--model_type` argument. + +
+日本語 + +`flux_train_network.py` で学習したモデルや、学習プロセスに役立つ関連スクリプトが提供されています: + +* **`networks/flux_extract_lora.py`**: 学習済みモデルとベースモデルの差分から LoRA モデルを抽出。 +* **`convert_flux_lora.py`**: 学習した LoRA モデルを Diffusers (AI-Toolkit) 形式など他の形式に変換。 +* **`networks/flux_merge_lora.py`**: 学習した LoRA モデルを FLUX.1 ベースモデルにマージ。 +* **`flux_minimal_inference.py`**: 学習した LoRA モデルを適用して画像を生成するシンプルな推論スクリプト。 + `--model_type` 引数で `flux` または `chroma` を指定できます。 + +
+ +## 9. Others / その他 + +`flux_train_network.py` includes many features common with `train_network.py`, such as sample image generation (`--sample_prompts`, etc.) and detailed optimizer settings. For these features, refer to the [`train_network.py` guide](train_network.md#5-other-features--その他の機能) or the script help (`python flux_train_network.py --help`). + +
+日本語 + +`flux_train_network.py`には、サンプル画像の生成 (`--sample_prompts`など) や詳細なオプティマイザ設定など、`train_network.py`と共通の機能も多く存在します。これらについては、[`train_network.py`のガイド](train_network.md#5-other-features--その他の機能)やスクリプトのヘルプ (`python flux_train_network.py --help`) を参照してください。 + +
diff --git a/docs/gen_img_README-ja.md b/docs/gen_img_README-ja.md index 8f4442d00..ca2eeab2a 100644 --- a/docs/gen_img_README-ja.md +++ b/docs/gen_img_README-ja.md @@ -3,7 +3,7 @@ SD 1.xおよび2.xのモデル、当リポジトリで学習したLoRA、Control # 概要 * Diffusers (v0.10.2) ベースの推論(画像生成)スクリプト。 -* SD 1.xおよび2.x (base/v-parameterization)モデルに対応。 +* SD 1.x、2.x (base/v-parameterization)、およびSDXLモデルに対応。 * txt2img、img2img、inpaintingに対応。 * 対話モード、およびファイルからのプロンプト読み込み、連続生成に対応。 * プロンプト1行あたりの生成枚数を指定可能。 @@ -96,14 +96,20 @@ python gen_img_diffusers.py --ckpt <モデル名> --outdir <画像出力先> - `--ckpt <モデル名>`:モデル名を指定します。`--ckpt`オプションは必須です。Stable Diffusionのcheckpointファイル、またはDiffusersのモデルフォルダ、Hugging FaceのモデルIDを指定できます。 +- `--v1`:Stable Diffusion 1.x系のモデルを使う場合に指定します。これがデフォルトの動作です。 + - `--v2`:Stable Diffusion 2.x系のモデルを使う場合に指定します。1.x系の場合には指定不要です。 +- `--sdxl`:Stable Diffusion XLモデルを使う場合に指定します。 + - `--v_parameterization`:v-parameterizationを使うモデルを使う場合に指定します(`768-v-ema.ckpt`およびそこからの追加学習モデル、Waifu Diffusion v1.5など)。 - `--v2`の指定有無が間違っているとモデル読み込み時にエラーになります。`--v_parameterization`の指定有無が間違っていると茶色い画像が表示されます。 + `--v2`や`--sdxl`の指定有無が間違っているとモデル読み込み時にエラーになります。`--v_parameterization`の指定有無が間違っていると茶色い画像が表示されます。 - `--vae`:使用するVAEを指定します。未指定時はモデル内のVAEを使用します。 +- `--tokenizer_cache_dir`:トークナイザーのキャッシュディレクトリを指定します(オフライン利用のため)。 + ## 画像生成と出力 - `--interactive`:インタラクティブモードで動作します。プロンプトを入力すると画像が生成されます。 @@ -112,6 +118,10 @@ python gen_img_diffusers.py --ckpt <モデル名> --outdir <画像出力先> - `--from_file <プロンプトファイル名>`:プロンプトが記述されたファイルを指定します。1行1プロンプトで記述してください。なお画像サイズやguidance scaleはプロンプトオプション(後述)で指定できます。 +- `--from_module <モジュールファイル>`:Pythonモジュールからプロンプトを読み込みます。モジュールは`get_prompter(args, pipe, networks)`関数を実装している必要があります。 + +- `--prompter_module_args`:prompterモジュールに渡す追加の引数を指定します。 + - `--W <画像幅>`:画像の幅を指定します。デフォルトは`512`です。 - `--H <画像高さ>`:画像の高さを指定します。デフォルトは`512`です。 @@ -132,6 +142,24 @@ python gen_img_diffusers.py --ckpt <モデル名> --outdir <画像出力先> - `--negative_scale` : uncoditioningのguidance scaleを個別に指定します。[gcem156氏のこちらの記事](https://note.com/gcem156/n/ne9a53e4a6f43)を参考に実装したものです。 +- `--emb_normalize_mode`:embedding正規化モードを指定します。"original"(デフォルト)、"abs"、"none"から選択できます。プロンプトの重みの正規化方法に影響します。 + +## SDXL固有のオプション + +SDXL モデル(`--sdxl`フラグ付き)を使用する場合、追加のコンディショニングオプションが利用できます: + +- `--original_height`:SDXL コンディショニング用の元の高さを指定します。これはモデルの対象解像度の理解に影響します。 + +- `--original_width`:SDXL コンディショニング用の元の幅を指定します。これはモデルの対象解像度の理解に影響します。 + +- `--original_height_negative`:SDXL ネガティブコンディショニング用の元の高さを指定します。 + +- `--original_width_negative`:SDXL ネガティブコンディショニング用の元の幅を指定します。 + +- `--crop_top`:SDXL コンディショニング用のクロップ上オフセットを指定します。 + +- `--crop_left`:SDXL コンディショニング用のクロップ左オフセットを指定します。 + ## メモリ使用量や生成速度の調整 - `--batch_size <バッチサイズ>`:バッチサイズを指定します。デフォルトは`1`です。バッチサイズが大きいとメモリを多く消費しますが、生成速度が速くなります。 @@ -139,8 +167,16 @@ python gen_img_diffusers.py --ckpt <モデル名> --outdir <画像出力先> - `--vae_batch_size `:VAEのバッチサイズを指定します。デフォルトはバッチサイズと同じです。 VAEのほうがメモリを多く消費するため、デノイジング後(stepが100%になった後)でメモリ不足になる場合があります。このような場合にはVAEのバッチサイズを小さくしてください。 +- `--vae_slices <スライス数>`:VAE処理時に画像をスライスに分割してVRAM使用量を削減します。None(デフォルト)で分割なし。16や32のような値が推奨されます。有効にすると処理が遅くなりますが、VRAM使用量が少なくなります。 + +- `--no_half_vae`:VAE処理でfp16/bf16精度の使用を防ぎます。代わりにfp32を使用します。VAE関連の問題やアーティファクトが発生した場合に使用してください。 + - `--xformers`:xformersを使う場合に指定します。 +- `--sdpa`:最適化のためにPyTorch 2のscaled dot-product attentionを使用します。 + +- `--diffusers_xformers`:Diffusers経由でxformersを使用します(注:Hypernetworksと互換性がありません)。 + - `--fp16`:fp16(単精度)での推論を行います。`fp16`と`bf16`をどちらも指定しない場合はfp32(単精度)での推論を行います。 - `--bf16`:bf16(bfloat16)での推論を行います。RTX 30系のGPUでのみ指定可能です。`--bf16`オプションはRTX 30系以外のGPUではエラーになります。`fp16`よりも`bf16`のほうが推論結果がNaNになる(真っ黒の画像になる)可能性が低いようです。 @@ -157,6 +193,12 @@ python gen_img_diffusers.py --ckpt <モデル名> --outdir <画像出力先> - `--network_pre_calc`:使用する追加ネットワークの重みを生成ごとにあらかじめ計算します。プロンプトオプションの`--am`が使用できます。LoRA未使用時と同じ程度まで生成は高速化されますが、生成前に重みを計算する時間が必要で、またメモリ使用量も若干増加します。Regional LoRA使用時は無効になります 。 +- `--network_regional_mask_max_color_codes`:リージョナルマスクに使用する色コードの最大数を指定します。指定されていない場合、マスクはチャンネルごとに適用されます。Regional LoRAと組み合わせて、マスク内の色で定義できるリージョン数を制御するために使用されます。 + +- `--network_args`:key=value形式でネットワークモジュールに渡す追加引数を指定します。例: `--network_args "alpha=1.0,dropout=0.1"`。 + +- `--network_merge_n_models`:ネットワークマージを使用する場合、マージするモデル数を指定します(全ての読み込み済みネットワークをマージする代わりに)。 + # 主なオプションの指定例 次は同一プロンプトで64枚をバッチサイズ4で一括生成する例です。 @@ -235,7 +277,9 @@ python gen_img_diffusers.py --ckpt model.safetensors - `--sequential_file_name`:ファイル名を連番にするかどうかを指定します。指定すると生成されるファイル名が`im_000001.png`からの連番になります。 -- `--use_original_file_name`:指定すると生成ファイル名がオリジナルのファイル名と同じになります。 +- `--use_original_file_name`:指定すると生成ファイル名がオリジナルのファイル名の前に追加されます(img2imgモード用)。 + +- `--clip_vision_strength`:指定した強度でimg2img用のCLIP Vision Conditioningを有効にします。CLIP Visionモデルを使用して入力画像からのコンディショニングを強化します。 ## コマンドラインからの実行例 @@ -306,7 +350,9 @@ img2imgと併用できません。 - `--highres_fix_upscaler`:2nd stageに任意のupscalerを利用します。現在は`--highres_fix_upscaler tools.latent_upscaler` のみ対応しています。 - `--highres_fix_upscaler_args`:`--highres_fix_upscaler`で指定したupscalerに渡す引数を指定します。 - `tools.latent_upscaler`の場合は、`--highres_fix_upscaler_args "weights=D:\Work\SD\Models\others\etc\upscaler-v1-e100-220.safetensors"`のように重みファイルを指定します。 + `tools.latent_upscaler`の場合は、`--highres_fix_upscaler_args "weights=D:\Work\SD\Models\others\etc\upscaler-v1-e100-220.safetensors"`のように重みファイルを指定します。 + +- `--highres_fix_disable_control_net`:Highres fixの2nd stageでControlNetを無効にします。デフォルトでは、ControlNetは両ステージで使用されます。 コマンドラインの例です。 @@ -319,6 +365,34 @@ python gen_img_diffusers.py --ckpt trinart_characters_it4_v1_vae_merged.ckpt --highres_fix_scale 0.5 --highres_fix_steps 28 --strength 0.5 ``` +## Deep Shrink + +Deep Shrinkは、異なるタイムステップで異なる深度のUNetを使用して生成プロセスを最適化する技術です。生成品質と効率を向上させることができます。 + +以下のオプションがあります: + +- `--ds_depth_1`:第1フェーズでこの深度のDeep Shrinkを有効にします。有効な値は0から8です。 + +- `--ds_timesteps_1`:このタイムステップまでDeep Shrink深度1を適用します。デフォルトは650です。 + +- `--ds_depth_2`:Deep Shrinkの第2フェーズの深度を指定します。 + +- `--ds_timesteps_2`:このタイムステップまでDeep Shrink深度2を適用します。デフォルトは650です。 + +- `--ds_ratio`:Deep Shrinkでのダウンサンプリングの比率を指定します。デフォルトは0.5です。 + +これらのパラメータはプロンプトオプションでも指定できます: + +- `--dsd1`:プロンプトからDeep Shrink深度1を指定します。 + +- `--dst1`:プロンプトからDeep Shrinkタイムステップ1を指定します。 + +- `--dsd2`:プロンプトからDeep Shrink深度2を指定します。 + +- `--dst2`:プロンプトからDeep Shrinkタイムステップ2を指定します。 + +- `--dsr`:プロンプトからDeep Shrink比率を指定します。 + ## ControlNet 現在はControlNet 1.0のみ動作確認しています。プリプロセスはCannyのみサポートしています。 @@ -346,6 +420,20 @@ python gen_img_diffusers.py --ckpt model_ckpt --scale 8 --steps 48 --outdir txt2 --guide_image_path guide.png --control_net_ratios 1.0 --interactive ``` +## ControlNet-LLLite + +ControlNet-LLLiteは、類似の誘導目的に使用できるControlNetの軽量な代替手段です。 + +以下のオプションがあります: + +- `--control_net_lllite_models`:ControlNet-LLLiteモデルファイルを指定します。 + +- `--control_net_multipliers`:ControlNet-LLLiteの倍率を指定します(重みに類似)。 + +- `--control_net_ratios`:ControlNet-LLLiteを適用するステップの比率を指定します。 + +注意:ControlNetとControlNet-LLLiteは同時に使用できません。 + ## Attention Couple + Reginal LoRA プロンプトをいくつかの部分に分割し、それぞれのプロンプトを画像内のどの領域に適用するかを指定できる機能です。個別のオプションはありませんが、`mask_path`とプロンプトで指定します。 @@ -450,7 +538,9 @@ python gen_img_diffusers.py --ckpt wd-v1-3-full-pruned-half.ckpt - `--opt_channels_last` : 推論時にテンソルのチャンネルを最後に配置します。場合によっては高速化されることがあります。 -- `--network_show_meta` : 追加ネットワークのメタデータを表示します。 +- `--shuffle_prompts`:繰り返し時にプロンプトの順序をシャッフルします。`--from_file`で複数のプロンプトを使用する場合に便利です。 + +- `--network_show_meta`:追加ネットワークのメタデータを表示します。 --- @@ -478,6 +568,8 @@ latentのサイズを徐々に大きくしていくHires fixです。`gen_img.py - `--gradual_latent_ratio` : latentの初期サイズを指定します。デフォルトは 0.5 で、デフォルトの latent サイズの半分のサイズから始めます。 - `--gradual_latent_ratio_step`: latentのサイズを大きくする割合を指定します。デフォルトは 0.125 で、latentのサイズを 0.625, 0.75, 0.875, 1.0 と徐々に大きくします。 - `--gradual_latent_ratio_every_n_steps`: latentのサイズを大きくする間隔を指定します。デフォルトは 3 で、3ステップごとに latent のサイズを大きくします。 +- `--gradual_latent_s_noise`:Gradual LatentのS_noiseパラメータを指定します。デフォルトは1.0です。 +- `--gradual_latent_unsharp_params`:Gradual Latentのアンシャープマスクパラメータをksize,sigma,strength,target-x形式で指定します(target-x: 1=True, 0=False)。推奨値:`3,0.5,0.5,1`または`3,1.0,1.0,0`。 それぞれのオプションは、プロンプトオプション、`--glt`、`--glr`、`--gls`、`--gle` でも指定できます。 diff --git a/docs/gen_img_README.md b/docs/gen_img_README.md new file mode 100644 index 000000000..bcfbef7f7 --- /dev/null +++ b/docs/gen_img_README.md @@ -0,0 +1,585 @@ + +This is an inference (image generation) script that supports SD 1.x and 2.x models, LoRA trained with this repository, ControlNet (only v1.0 has been confirmed to work), etc. It is used from the command line. + +# Overview + +* Inference (image generation) script. +* Supports SD 1.x, 2.x (base/v-parameterization), and SDXL models. +* Supports txt2img, img2img, and inpainting. +* Supports interactive mode, prompt reading from files, and continuous generation. +* The number of images generated per prompt line can be specified. +* The total number of repetitions can be specified. +* Supports not only `fp16` but also `bf16`. +* Supports xformers for high-speed generation. + * Although xformers are used for memory-saving generation, it is not as optimized as Automatic 1111's Web UI, so it uses about 6GB of VRAM for 512*512 image generation. +* Extension of prompts to 225 tokens. Supports negative prompts and weighting. +* Supports various samplers from Diffusers including ddim, pndm, lms, euler, euler_a, heun, dpm_2, dpm_2_a, dpmsolver, dpmsolver++, dpmsingle. +* Supports clip skip (uses the output of the nth layer from the end) of Text Encoder. +* Separate loading of VAE. +* Supports CLIP Guided Stable Diffusion, VGG16 Guided Stable Diffusion, Highres. fix, and upscale. + * Highres. fix is an original implementation that has not confirmed the Web UI implementation at all, so the output results may differ. +* LoRA support. Supports application rate specification, simultaneous use of multiple LoRAs, and weight merging. + * It is not possible to specify different application rates for Text Encoder and U-Net. +* Supports Attention Couple. +* Supports ControlNet v1.0. +* Supports Deep Shrink for optimizing generation at different depths. +* Supports Gradual Latent for progressive upscaling during generation. +* Supports CLIP Vision Conditioning for img2img. +* It is not possible to switch models midway, but it can be handled by creating a batch file. +* Various personally desired features have been added. + +Since not all tests are performed when adding features, it is possible that previous features may be affected and some features may not work. Please let us know if you have any problems. + +# Basic Usage + +## Image Generation in Interactive Mode + +Enter as follows: + +```batchfile +python gen_img.py --ckpt --outdir --xformers --fp16 --interactive +``` + +Specify the model (Stable Diffusion checkpoint file or Diffusers model folder) in the `--ckpt` option and the image output destination folder in the `--outdir` option. + +Specify the use of xformers with the `--xformers` option (remove it if you do not use xformers). The `--fp16` option performs inference in fp16 (single precision). For RTX 30 series GPUs, you can also perform inference in bf16 (bfloat16) with the `--bf16` option. + +The `--interactive` option specifies interactive mode. + +If you are using Stable Diffusion 2.0 (or a model with additional training from it), add the `--v2` option. If you are using a model that uses v-parameterization (`768-v-ema.ckpt` and models with additional training from it), add `--v_parameterization` as well. + +If the `--v2` specification is incorrect, an error will occur when loading the model. If the `--v_parameterization` specification is incorrect, a brown image will be displayed. + +When `Type prompt:` is displayed, enter the prompt. + +![image](https://user-images.githubusercontent.com/52813779/235343115-f3b8ac82-456d-4aab-9724-0cc73c4534aa.png) + +*If the image is not displayed and an error occurs, headless (no screen display function) OpenCV may be installed. Install normal OpenCV with `pip install opencv-python`. Alternatively, stop image display with the `--no_preview` option. + +Select the image window and press any key to close the window and enter the next prompt. Press Ctrl+Z and then Enter in the prompt to close the script. + +## Batch Generation of Images with a Single Prompt + +Enter as follows (actually entered on one line): + +```batchfile +python gen_img.py --ckpt --outdir \ + --xformers --fp16 --images_per_prompt --prompt "" +``` + +Specify the number of images to generate per prompt with the `--images_per_prompt` option. Specify the prompt with the `--prompt` option. If it contains spaces, enclose it in double quotes. + +You can specify the batch size with the `--batch_size` option (described later). + +## Batch Generation by Reading Prompts from a File + +Enter as follows: + +```batchfile +python gen_img.py --ckpt --outdir \ + --xformers --fp16 --from_file +``` + +Specify the file containing the prompts with the `--from_file` option. Write one prompt per line. You can specify the number of images to generate per line with the `--images_per_prompt` option. + +## Using Negative Prompts and Weighting + +If you write `--n` in the prompt options (specified like `--x` in the prompt, described later), the following will be a negative prompt. + +Also, weighting with `()` and `[]`, `(xxx:1.3)`, etc., similar to AUTOMATIC1111's Web UI, is possible (the implementation is copied from Diffusers' [Long Prompt Weighting Stable Diffusion](https://github.com/huggingface/diffusers/blob/main/examples/community/README.md#long-prompt-weighting-stable-diffusion)). + +It can be specified similarly for prompt specification from the command line and prompt reading from files. + +![image](https://user-images.githubusercontent.com/52813779/235343128-e79cd768-ec59-46f5-8395-fce9bdc46208.png) + +# Main Options + +Specify from the command line. + +## Model Specification + +- `--ckpt `: Specifies the model name. The `--ckpt` option is mandatory. You can specify a Stable Diffusion checkpoint file, a Diffusers model folder, or a Hugging Face model ID. + +- `--v1`: Specify when using Stable Diffusion 1.x series models. This is the default behavior. + +- `--v2`: Specify when using Stable Diffusion 2.x series models. Not required for 1.x series. + +- `--sdxl`: Specify when using Stable Diffusion XL models. + +- `--v_parameterization`: Specify when using models that use v-parameterization (`768-v-ema.ckpt` and models with additional training from it, Waifu Diffusion v1.5, etc.). + + If the `--v2` or `--sdxl` specification is incorrect, an error will occur when loading the model. If the `--v_parameterization` specification is incorrect, a brown image will be displayed. + +- `--vae`: Specifies the VAE to use. If not specified, the VAE in the model will be used. + +- `--tokenizer_cache_dir`: Specifies the cache directory for the tokenizer (for offline usage). + +## Image Generation and Output + +- `--interactive`: Operates in interactive mode. Images are generated when prompts are entered. + +- `--prompt `: Specifies the prompt. If it contains spaces, enclose it in double quotes. + +- `--from_file `: Specifies the file containing the prompts. Write one prompt per line. Image size and guidance scale can be specified with prompt options (described later). + +- `--from_module `: Loads prompts from a Python module. The module should implement a `get_prompter(args, pipe, networks)` function. + +- `--prompter_module_args`: Specifies additional arguments to pass to the prompter module. + +- `--W `: Specifies the width of the image. The default is `512`. + +- `--H `: Specifies the height of the image. The default is `512`. + +- `--steps `: Specifies the number of sampling steps. The default is `50`. + +- `--scale `: Specifies the unconditional guidance scale. The default is `7.5`. + +- `--sampler `: Specifies the sampler. The default is `ddim`. The following samplers are supported: ddim, pndm, lms, euler, euler_a, heun, dpm_2, dpm_2_a, dpmsolver, dpmsolver++, dpmsingle. Some can also be specified with k_ prefix (k_lms, k_euler, k_euler_a, k_dpm_2, k_dpm_2_a). + +- `--outdir `: Specifies the output destination for images. + +- `--images_per_prompt `: Specifies the number of images to generate per prompt. The default is `1`. + +- `--clip_skip `: Specifies which layer from the end of CLIP to use. If omitted, the last layer is used. + +- `--max_embeddings_multiples `: Specifies how many times the CLIP input/output length should be multiplied by the default (75). If not specified, it remains 75. For example, specifying 3 makes the input/output length 225. + +- `--negative_scale`: Specifies the guidance scale for unconditioning individually. Implemented with reference to [this article by gcem156](https://note.com/gcem156/n/ne9a53e4a6f43). + +- `--emb_normalize_mode`: Specifies the embedding normalization mode. Options are "original" (default), "abs", and "none". This affects how prompt weights are normalized. + +## SDXL-Specific Options + +When using SDXL models (with `--sdxl` flag), additional conditioning options are available: + +- `--original_height`: Specifies the original height for SDXL conditioning. This affects the model's understanding of the target resolution. + +- `--original_width`: Specifies the original width for SDXL conditioning. This affects the model's understanding of the target resolution. + +- `--original_height_negative`: Specifies the original height for SDXL negative conditioning. + +- `--original_width_negative`: Specifies the original width for SDXL negative conditioning. + +- `--crop_top`: Specifies the crop top offset for SDXL conditioning. + +- `--crop_left`: Specifies the crop left offset for SDXL conditioning. + +## Adjusting Memory Usage and Generation Speed + +- `--batch_size `: Specifies the batch size. The default is `1`. A larger batch size consumes more memory but speeds up generation. + +- `--vae_batch_size `: Specifies the VAE batch size. The default is the same as the batch size. + Since VAE consumes more memory, memory shortages may occur after denoising (after the step reaches 100%). In such cases, reduce the VAE batch size. + +- `--vae_slices `: Splits the image into slices for VAE processing to reduce VRAM usage. None (default) for no splitting. Values like 16 or 32 are recommended. Enabling this is slower but uses less VRAM. + +- `--no_half_vae`: Prevents using fp16/bf16 precision for VAE processing. Uses fp32 instead. Use this if you encounter VAE-related issues or artifacts. + +- `--xformers`: Specify when using xformers. + +- `--sdpa`: Use scaled dot-product attention in PyTorch 2 for optimization. + +- `--diffusers_xformers`: Use xformers via Diffusers (note: incompatible with Hypernetworks). + +- `--fp16`: Performs inference in fp16 (single precision). If neither `fp16` nor `bf16` is specified, inference is performed in fp32 (single precision). + +- `--bf16`: Performs inference in bf16 (bfloat16). Can only be specified for RTX 30 series GPUs. The `--bf16` option will cause an error on GPUs other than the RTX 30 series. It seems that `bf16` is less likely to result in NaN (black image) inference results than `fp16`. + +## Using Additional Networks (LoRA, etc.) + +- `--network_module`: Specifies the additional network to use. For LoRA, specify `--network_module networks.lora`. To use multiple LoRAs, specify like `--network_module networks.lora networks.lora networks.lora`. + +- `--network_weights`: Specifies the weight file of the additional network to use. Specify like `--network_weights model.safetensors`. To use multiple LoRAs, specify like `--network_weights model1.safetensors model2.safetensors model3.safetensors`. The number of arguments should be the same as the number specified in `--network_module`. + +- `--network_mul`: Specifies how many times to multiply the weight of the additional network to use. The default is `1`. Specify like `--network_mul 0.8`. To use multiple LoRAs, specify like `--network_mul 0.4 0.5 0.7`. The number of arguments should be the same as the number specified in `--network_module`. + +- `--network_merge`: Merges the weights of the additional networks to be used in advance with the weights specified in `--network_mul`. Cannot be used simultaneously with `--network_pre_calc`. The prompt option `--am` and Regional LoRA can no longer be used, but generation will be accelerated to the same extent as when LoRA is not used. + +- `--network_pre_calc`: Calculates the weights of the additional network to be used in advance for each generation. The prompt option `--am` can be used. Generation is accelerated to the same extent as when LoRA is not used, but time is required to calculate the weights before generation, and memory usage also increases slightly. It is disabled when Regional LoRA is used. + +- `--network_regional_mask_max_color_codes`: Specifies the maximum number of color codes to use for regional masks. If not specified, masks are applied by channel. Used with Regional LoRA to control the number of regions that can be defined by colors in the mask. + +- `--network_args`: Specifies additional arguments to pass to the network module in key=value format. For example: `--network_args "alpha=1.0,dropout=0.1"`. + +- `--network_merge_n_models`: When using network merging, specifies the number of models to merge (instead of merging all loaded networks). + +# Examples of Main Option Specifications + +The following is an example of batch generating 64 images with the same prompt and a batch size of 4. + +```batchfile +python gen_img.py --ckpt model.ckpt --outdir outputs \ + --xformers --fp16 --W 512 --H 704 --scale 12.5 --sampler k_euler_a \ + --steps 32 --batch_size 4 --images_per_prompt 64 \ + --prompt "beautiful flowers --n monochrome" +``` + +The following is an example of batch generating 10 images each for prompts written in a file, with a batch size of 4. + +```batchfile +python gen_img.py --ckpt model.ckpt --outdir outputs \ + --xformers --fp16 --W 512 --H 704 --scale 12.5 --sampler k_euler_a \ + --steps 32 --batch_size 4 --images_per_prompt 10 \ + --from_file prompts.txt +``` + +Example of using Textual Inversion (described later) and LoRA. + +```batchfile +python gen_img.py --ckpt model.safetensors \ + --scale 8 --steps 48 --outdir txt2img --xformers \ + --W 512 --H 768 --fp16 --sampler k_euler_a \ + --textual_inversion_embeddings goodembed.safetensors negprompt.pt \ + --network_module networks.lora networks.lora \ + --network_weights model1.safetensors model2.safetensors \ + --network_mul 0.4 0.8 \ + --clip_skip 2 --max_embeddings_multiples 1 \ + --batch_size 8 --images_per_prompt 1 --interactive +``` + +# Prompt Options + +In the prompt, you can specify various options from the prompt with "two hyphens + n alphabetic characters" like `--n`. It is valid whether specifying the prompt from interactive mode, command line, or file. + +Please put spaces before and after the prompt option specification `--n`. + +- `--n`: Specifies a negative prompt. + +- `--w`: Specifies the image width. Overrides the command line specification. + +- `--h`: Specifies the image height. Overrides the command line specification. + +- `--s`: Specifies the number of steps. Overrides the command line specification. + +- `--d`: Specifies the random seed for this image. If `--images_per_prompt` is specified, specify multiple seeds separated by commas, like "--d 1,2,3,4". + *For various reasons, the generated image may differ from the Web UI even with the same random seed. + +- `--l`: Specifies the guidance scale. Overrides the command line specification. + +- `--t`: Specifies the strength of img2img (described later). Overrides the command line specification. + +- `--nl`: Specifies the guidance scale for negative prompts (described later). Overrides the command line specification. + +- `--am`: Specifies the weight of the additional network. Overrides the command line specification. If using multiple additional networks, specify them separated by __commas__, like `--am 0.8,0.5,0.3`. + +- `--glt`: Specifies the timestep to start increasing the size of the latent for Gradual Latent. Overrides the command line specification. + +- `--glr`: Specifies the initial size of the latent for Gradual Latent as a ratio. Overrides the command line specification. + +- `--gls`: Specifies the ratio to increase the size of the latent for Gradual Latent. Overrides the command line specification. + +- `--gle`: Specifies the interval to increase the size of the latent for Gradual Latent. Overrides the command line specification. + +*Specifying these options may cause the batch to be executed with a size smaller than the batch size (because they cannot be generated collectively if these values are different). (You don't have to worry too much, but when reading prompts from a file and generating, arranging prompts with the same values for these options will improve efficiency.) + +Example: +``` +(masterpiece, best quality), 1girl, in shirt and plated skirt, standing at street under cherry blossoms, upper body, [from below], kind smile, looking at another, [goodembed] --n realistic, real life, (negprompt), (lowres:1.1), (worst quality:1.2), (low quality:1.1), bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, normal quality, jpeg artifacts, signature, watermark, username, blurry --w 960 --h 640 --s 28 --d 1 +``` + +![image](https://user-images.githubusercontent.com/52813779/235343446-25654172-fff4-4aaf-977a-20d262b51676.png) + +# img2img + +## Options + +- `--image_path`: Specifies the image to use for img2img. Specify like `--image_path template.png`. If a folder is specified, images in that folder will be used sequentially. + +- `--strength`: Specifies the strength of img2img. Specify like `--strength 0.8`. The default is `0.8`. + +- `--sequential_file_name`: Specifies whether to make file names sequential. If specified, the generated file names will be sequential starting from `im_000001.png`. + +- `--use_original_file_name`: If specified, the generated file name will be prepended with the original file name (for img2img mode). + +- `--clip_vision_strength`: Enables CLIP Vision Conditioning for img2img with the specified strength. Uses the CLIP Vision model to enhance conditioning from the input image. + +## Command Line Execution Example + +```batchfile +python gen_img.py --ckpt trinart_characters_it4_v1_vae_merged.ckpt \ + --outdir outputs --xformers --fp16 --scale 12.5 --sampler k_euler --steps 32 \ + --image_path template.png --strength 0.8 \ + --prompt "1girl, cowboy shot, brown hair, pony tail, brown eyes, \ + sailor school uniform, outdoors \ + --n lowres, bad anatomy, bad hands, error, missing fingers, cropped, \ + worst quality, low quality, normal quality, jpeg artifacts, (blurry), \ + hair ornament, glasses" \ + --batch_size 8 --images_per_prompt 32 +``` + +If a folder is specified in the `--image_path` option, images in that folder will be read sequentially. The number of images generated will be the number of prompts, not the number of images, so please match the number of images to img2img and the number of prompts by specifying the `--images_per_prompt` option. + +Files are read sorted by file name. Note that the sort order is string order (not `1.jpg -> 2.jpg -> 10.jpg` but `1.jpg -> 10.jpg -> 2.jpg`), so please pad the beginning with zeros (e.g., `01.jpg -> 02.jpg -> 10.jpg`). + +## Upscale using img2img + +If you specify the generated image size with the `--W` and `--H` command line options during img2img, the original image will be resized to that size before img2img. + +Also, if the original image for img2img was generated by this script, omitting the prompt will retrieve the prompt from the original image's metadata and use it as is. This allows you to perform only the 2nd stage operation of Highres. fix. + +## Inpainting during img2img + +You can specify an image and a mask image for inpainting (inpainting models are not supported, it simply performs img2img on the mask area). + +The options are as follows: + +- `--mask_image`: Specifies the mask image. Similar to `--img_path`, if a folder is specified, images in that folder will be used sequentially. + +The mask image is a grayscale image, and the white parts will be inpainted. It is recommended to gradient the boundaries to make it somewhat smooth. + +![image](https://user-images.githubusercontent.com/52813779/235343795-9eaa6d98-02ff-4f32-b089-80d1fc482453.png) + +# Other Features + +## Textual Inversion + +Specify the embeddings to use with the `--textual_inversion_embeddings` option (multiple specifications possible). By using the file name without the extension in the prompt, that embedding will be used (same usage as Web UI). It can also be used in negative prompts. + +As models, you can use Textual Inversion models trained with this repository and Textual Inversion models trained with Web UI (image embedding is not supported). + +## Extended Textual Inversion + +Specify the `--XTI_embeddings` option instead of `--textual_inversion_embeddings`. Usage is the same as `--textual_inversion_embeddings`. + +## Highres. fix + +This is a similar feature to the one in AUTOMATIC1111's Web UI (it may differ in various ways as it is an original implementation). It first generates a smaller image and then uses that image as a base for img2img to generate a large resolution image while preventing the entire image from collapsing. + +The number of steps for the 2nd stage is calculated from the values of the `--steps` and `--strength` options (`steps*strength`). + +Cannot be used with img2img. + +The following options are available: + +- `--highres_fix_scale`: Enables Highres. fix and specifies the size of the image generated in the 1st stage as a magnification. If the final output is 1024x1024 and you want to generate a 512x512 image first, specify like `--highres_fix_scale 0.5`. Please note that this is the reciprocal of the specification in Web UI. + +- `--highres_fix_steps`: Specifies the number of steps for the 1st stage image. The default is `28`. + +- `--highres_fix_save_1st`: Specifies whether to save the 1st stage image. + +- `--highres_fix_latents_upscaling`: If specified, the 1st stage image will be upscaled on a latent basis during 2nd stage image generation (only bilinear is supported). If not specified, the image will be upscaled with LANCZOS4. + +- `--highres_fix_upscaler`: Uses an arbitrary upscaler for the 2nd stage. Currently, only `--highres_fix_upscaler tools.latent_upscaler` is supported. + +- `--highres_fix_upscaler_args`: Specifies the arguments to pass to the upscaler specified with `--highres_fix_upscaler`. + For `tools.latent_upscaler`, specify the weight file like `--highres_fix_upscaler_args "weights=D:\\Work\\SD\\Models\\others\\etc\\upscaler-v1-e100-220.safetensors"`. + +- `--highres_fix_disable_control_net`: Disables ControlNet for the 2nd stage of Highres fix. By default, ControlNet is used in both stages. + +Command line example: + +```batchfile +python gen_img.py --ckpt trinart_characters_it4_v1_vae_merged.ckpt\ + --n_iter 1 --scale 7.5 --W 1024 --H 1024 --batch_size 1 --outdir ../txt2img \ + --steps 48 --sampler ddim --fp16 \ + --xformers \ + --images_per_prompt 1 --interactive \ + --highres_fix_scale 0.5 --highres_fix_steps 28 --strength 0.5 +``` + +## Deep Shrink + +Deep Shrink is a technique that optimizes the generation process by using different depths of the UNet at different timesteps. It can improve generation quality and efficiency. + +The following options are available: + +- `--ds_depth_1`: Enables Deep Shrink with this depth for the first phase. Valid values are 0 to 8. + +- `--ds_timesteps_1`: Applies Deep Shrink depth 1 until this timestep. Default is 650. + +- `--ds_depth_2`: Specifies the depth for the second phase of Deep Shrink. + +- `--ds_timesteps_2`: Applies Deep Shrink depth 2 until this timestep. Default is 650. + +- `--ds_ratio`: Specifies the ratio for downsampling in Deep Shrink. Default is 0.5. + +These parameters can also be specified through prompt options: + +- `--dsd1`: Specifies Deep Shrink depth 1 from the prompt. + +- `--dst1`: Specifies Deep Shrink timestep 1 from the prompt. + +- `--dsd2`: Specifies Deep Shrink depth 2 from the prompt. + +- `--dst2`: Specifies Deep Shrink timestep 2 from the prompt. + +- `--dsr`: Specifies Deep Shrink ratio from the prompt. + +*Additional prompt options for Gradual Latent (requires `euler_a` sampler):* + +- `--glt`: Specifies the timestep to start increasing the size of the latent for Gradual Latent. Overrides the command line specification. + +- `--glr`: Specifies the initial size of the latent for Gradual Latent as a ratio. Overrides the command line specification. + +- `--gls`: Specifies the ratio to increase the size of the latent for Gradual Latent. Overrides the command line specification. + +- `--gle`: Specifies the interval to increase the size of the latent for Gradual Latent. Overrides the command line specification. + +## ControlNet + +Currently, only ControlNet 1.0 has been confirmed to work. Only Canny is supported for preprocessing. + +The following options are available: + +- `--control_net_models`: Specifies the ControlNet model file. + If multiple are specified, they will be switched and used for each step (differs from the implementation of the ControlNet extension in Web UI). Supports both diff and normal. + +- `--guide_image_path`: Specifies the hint image to use for ControlNet. Similar to `--img_path`, if a folder is specified, images in that folder will be used sequentially. For models other than Canny, please perform preprocessing beforehand. + +- `--control_net_preps`: Specifies the preprocessing for ControlNet. Multiple specifications are possible, similar to `--control_net_models`. Currently, only canny is supported. If preprocessing is not used for the target model, specify `none`. + For canny, you can specify thresholds 1 and 2 separated by `_`, like `--control_net_preps canny_63_191`. + +- `--control_net_weights`: Specifies the weight when applying ControlNet (`1.0` for normal, `0.5` for half influence). Multiple specifications are possible, similar to `--control_net_models`. + +- `--control_net_ratios`: Specifies the range of steps to apply ControlNet. If `0.5`, ControlNet is applied up to half the number of steps. Multiple specifications are possible, similar to `--control_net_models`. + +Command line example: + +```batchfile +python gen_img.py --ckpt model_ckpt --scale 8 --steps 48 --outdir txt2img --xformers \ + --W 512 --H 768 --bf16 --sampler k_euler_a \ + --control_net_models diff_control_sd15_canny.safetensors --control_net_weights 1.0 \ + --guide_image_path guide.png --control_net_ratios 1.0 --interactive +``` + +## ControlNet-LLLite + +ControlNet-LLLite is a lightweight alternative to ControlNet that can be used for similar guidance purposes. + +The following options are available: + +- `--control_net_lllite_models`: Specifies the ControlNet-LLLite model files. + +- `--control_net_multipliers`: Specifies the multiplier for ControlNet-LLLite (similar to weights). + +- `--control_net_ratios`: Specifies the ratio of steps to apply ControlNet-LLLite. + +Note that ControlNet and ControlNet-LLLite cannot be used at the same time. + +## Attention Couple + Regional LoRA + +This is a feature that allows you to divide the prompt into several parts and specify which region in the image each prompt should be applied to. There are no individual options, but it is specified with `mask_path` and the prompt. + +First, define multiple parts using ` AND ` in the prompt. Region specification can be done for the first three parts, and subsequent parts are applied to the entire image. Negative prompts are applied to the entire image. + +In the following, three parts are defined with AND. + +``` +shs 2girls, looking at viewer, smile AND bsb 2girls, looking back AND 2girls --n bad quality, worst quality +``` + +Next, prepare a mask image. The mask image is a color image, and each RGB channel corresponds to the part separated by AND in the prompt. Also, if the value of a certain channel is all 0, it is applied to the entire image. + +In the example above, the R channel corresponds to `shs 2girls, looking at viewer, smile`, the G channel to `bsb 2girls, looking back`, and the B channel to `2girls`. If you use a mask image like the following, since there is no specification for the B channel, `2girls` will be applied to the entire image. + +![image](https://user-images.githubusercontent.com/52813779/235343061-b4dc9392-3dae-4831-8347-1e9ae5054251.png) + +The mask image is specified with `--mask_path`. Currently, only one image is supported. It is automatically resized and applied to the specified image size. + +It can also be combined with ControlNet (combination with ControlNet is recommended for detailed position specification). + +If LoRA is specified, multiple LoRAs specified with `--network_weights` will correspond to each part of AND. As a current constraint, the number of LoRAs must be the same as the number of AND parts. + +## CLIP Guided Stable Diffusion + +The source code is copied and modified from [this custom pipeline](https://github.com/huggingface/diffusers/blob/main/examples/community/README.md#clip-guided-stable-diffusion) in Diffusers' Community Examples. + +In addition to the normal prompt-based generation specification, it additionally acquires the text features of the prompt with a larger CLIP and controls the generated image so that the features of the image being generated approach those text features (this is my rough understanding). Since a larger CLIP is used, VRAM usage increases considerably (it may be difficult even for 512*512 with 8GB of VRAM), and generation time also increases. + +Note that the selectable samplers are DDIM, PNDM, and LMS only. + +Specify how much to reflect the CLIP features numerically with the `--clip_guidance_scale` option. In the previous sample, it is 100, so it seems good to start around there and increase or decrease it. + +By default, the first 75 tokens of the prompt (excluding special weighting characters) are passed to CLIP. With the `--c` option in the prompt, you can specify the text to be passed to CLIP separately from the normal prompt (for example, it is thought that CLIP cannot recognize DreamBooth identifiers or model-specific words like "1girl", so text excluding them is considered good). + +Command line example: + +```batchfile +python gen_img.py --ckpt v1-5-pruned-emaonly.ckpt --n_iter 1 \ + --scale 2.5 --W 512 --H 512 --batch_size 1 --outdir ../txt2img --steps 36 \ + --sampler ddim --fp16 --opt_channels_last --xformers --images_per_prompt 1 \ + --interactive --clip_guidance_scale 100 +``` + +## CLIP Image Guided Stable Diffusion + +This is a feature that passes another image to CLIP instead of text and controls generation to approach its features. Specify the numerical value of the application amount with the `--clip_image_guidance_scale` option and the image (file or folder) to use for guidance with the `--guide_image_path` option. + +Command line example: + +```batchfile +python gen_img.py --ckpt trinart_characters_it4_v1_vae_merged.ckpt\ + --n_iter 1 --scale 7.5 --W 512 --H 512 --batch_size 1 --outdir ../txt2img \ + --steps 80 --sampler ddim --fp16 --opt_channels_last --xformers \ + --images_per_prompt 1 --interactive --clip_image_guidance_scale 100 \ + --guide_image_path YUKA160113420I9A4104_TP_V.jpg +``` + +### VGG16 Guided Stable Diffusion + +This is a feature that generates images to approach a specified image. In addition to the normal prompt-based generation specification, it additionally acquires the features of VGG16 and controls the generated image so that the image being generated approaches the specified guide image. It is recommended to use it with img2img (images tend to be blurred in normal generation). This is an original feature that reuses the mechanism of CLIP Guided Stable Diffusion. The idea is also borrowed from style transfer using VGG. + +Note that the selectable samplers are DDIM, PNDM, and LMS only. + +Specify how much to reflect the VGG16 features numerically with the `--vgg16_guidance_scale` option. From what I've tried, it seems good to start around 100 and increase or decrease it. Specify the image (file or folder) to use for guidance with the `--guide_image_path` option. + +When batch converting multiple images with img2img and using the original images as guide images, it is OK to specify the same value for `--guide_image_path` and `--image_path`. + +Command line example: + +```batchfile +python gen_img.py --ckpt wd-v1-3-full-pruned-half.ckpt \ + --n_iter 1 --scale 5.5 --steps 60 --outdir ../txt2img \ + --xformers --sampler ddim --fp16 --W 512 --H 704 \ + --batch_size 1 --images_per_prompt 1 \ + --prompt "picturesque, 1girl, solo, anime face, skirt, beautiful face \ + --n lowres, bad anatomy, bad hands, error, missing fingers, \ + cropped, worst quality, low quality, normal quality, \ + jpeg artifacts, blurry, 3d, bad face, monochrome --d 1" \ + --strength 0.8 --image_path ..\\src_image\ + --vgg16_guidance_scale 100 --guide_image_path ..\\src_image \ +``` + +You can specify the VGG16 layer number used for feature acquisition with `--vgg16_guidance_layerP` (default is 20, which is ReLU of conv4-2). It is said that upper layers express style and lower layers express content. + +![image](https://user-images.githubusercontent.com/52813779/235343813-3c1f0d7a-4fb3-4274-98e4-b92d76b551df.png) + +# Other Options + +- `--no_preview`: Does not display preview images in interactive mode. Specify this if OpenCV is not installed or if you want to check the output files directly. + +- `--n_iter`: Specifies the number of times to repeat generation. The default is 1. Specify this when you want to perform generation multiple times when reading prompts from a file. + +- `--tokenizer_cache_dir`: Specifies the cache directory for the tokenizer. (Work in progress) + +- `--seed`: Specifies the random seed. When generating one image, it is the seed for that image. When generating multiple images, it is the seed for the random numbers used to generate the seeds for each image (when generating multiple images with `--from_file`, specifying the `--seed` option will make each image have the same seed when executed multiple times). + +- `--iter_same_seed`: When there is no random seed specification in the prompt, the same seed is used for all repetitions of `--n_iter`. Used to unify and compare seeds between multiple prompts specified with `--from_file`. + +- `--shuffle_prompts`: Shuffles the order of prompts in iteration. Useful when using `--from_file` with multiple prompts. + +- `--diffusers_xformers`: Uses Diffuser's xformers. + +- `--opt_channels_last`: Arranges tensor channels last during inference. May speed up in some cases. + +- `--network_show_meta`: Displays the metadata of the additional network. + + +--- + +# About Gradual Latent + +Gradual Latent is a Hires fix that gradually increases the size of the latent. `gen_img.py`, `sdxl_gen_img.py`, and `gen_img.py` have the following options. + +- `--gradual_latent_timesteps`: Specifies the timestep to start increasing the size of the latent. The default is None, which means Gradual Latent is not used. Please try around 750 at first. +- `--gradual_latent_ratio`: Specifies the initial size of the latent. The default is 0.5, which means it starts with half the default latent size. +- `--gradual_latent_ratio_step`: Specifies the ratio to increase the size of the latent. The default is 0.125, which means the latent size is gradually increased to 0.625, 0.75, 0.875, 1.0. +- `--gradual_latent_ratio_every_n_steps`: Specifies the interval to increase the size of the latent. The default is 3, which means the latent size is increased every 3 steps. +- `--gradual_latent_s_noise`: Specifies the s_noise parameter for Gradual Latent. Default is 1.0. +- `--gradual_latent_unsharp_params`: Specifies unsharp mask parameters for Gradual Latent in the format: ksize,sigma,strength,target-x (where target-x: 1=True, 0=False). Recommended values: `3,0.5,0.5,1` or `3,1.0,1.0,0`. + +Each option can also be specified with prompt options, `--glt`, `--glr`, `--gls`, `--gle`. + +__Please specify `euler_a` for the sampler.__ Because the source code of the sampler is modified. It will not work with other samplers. + +It is more effective with SD 1.5. It is quite subtle with SDXL. diff --git a/docs/hunyuan_image_train_network.md b/docs/hunyuan_image_train_network.md new file mode 100644 index 000000000..b0e9cdd98 --- /dev/null +++ b/docs/hunyuan_image_train_network.md @@ -0,0 +1,525 @@ +Status: reviewed + +# LoRA Training Guide for HunyuanImage-2.1 using `hunyuan_image_train_network.py` / `hunyuan_image_train_network.py` を用いたHunyuanImage-2.1モデルのLoRA学習ガイド + +This document explains how to train LoRA models for the HunyuanImage-2.1 model using `hunyuan_image_train_network.py` included in the `sd-scripts` repository. + +
+日本語 + +このドキュメントでは、`sd-scripts`リポジトリに含まれる`hunyuan_image_train_network.py`を使用して、HunyuanImage-2.1モデルに対するLoRA (Low-Rank Adaptation) モデルを学習する基本的な手順について解説します。 + +
+ +## 1. Introduction / はじめに + +`hunyuan_image_train_network.py` trains additional networks such as LoRA on the HunyuanImage-2.1 model, which uses a transformer-based architecture (DiT) different from Stable Diffusion. Two text encoders, Qwen2.5-VL and byT5, and a dedicated VAE are used. + +This guide assumes you know the basics of LoRA training. For common options see [train_network.py](train_network.md) and [sdxl_train_network.py](sdxl_train_network.md). + +**Prerequisites:** + +* The repository is cloned and the Python environment is ready. +* A training dataset is prepared. See the dataset configuration guide. + +
+日本語 + +`hunyuan_image_train_network.py`はHunyuanImage-2.1モデルに対してLoRAなどの追加ネットワークを学習させるためのスクリプトです。HunyuanImage-2.1はStable Diffusionとは異なるDiT (Diffusion Transformer) アーキテクチャを持つ画像生成モデルであり、このスクリプトを使用することで、特定のキャラクターや画風を再現するLoRAモデルを作成できます。 + +このガイドは、基本的なLoRA学習の手順を理解しているユーザーを対象としています。基本的な使い方や共通のオプションについては、[`train_network.py`のガイド](train_network.md)を参照してください。また一部のパラメータは [`sdxl_train_network.py`](sdxl_train_network.md) や [`flux_train_network.py`](flux_train_network.md) と同様のものがあるため、そちらも参考にしてください。 + +**前提条件:** + +* `sd-scripts`リポジトリのクローンとPython環境のセットアップが完了していること。 +* 学習用データセットの準備が完了していること。(データセットの準備については[データセット設定ガイド](config_README-ja.md)を参照してください) + +
+ +## 2. Differences from `train_network.py` / `train_network.py` との違い + +`hunyuan_image_train_network.py` is based on `train_network.py` but adapted for HunyuanImage-2.1. Main differences include: + +* **Target model:** HunyuanImage-2.1 model. +* **Model structure:** HunyuanImage-2.1 uses a Transformer-based architecture (DiT). It uses two text encoders (Qwen2.5-VL and byT5) and a dedicated VAE. +* **Required arguments:** Additional arguments for the DiT model, Qwen2.5-VL, byT5, and VAE model files. +* **Incompatible options:** Some Stable Diffusion-specific arguments (e.g., `--v2`, `--clip_skip`, `--max_token_length`) are not used. +* **HunyuanImage-2.1-specific arguments:** Additional arguments for specific training parameters like flow matching. + +
+日本語 + +`hunyuan_image_train_network.py`は`train_network.py`をベースに、HunyuanImage-2.1モデルに対応するための変更が加えられています。主な違いは以下の通りです。 + +* **対象モデル:** HunyuanImage-2.1モデルを対象とします。 +* **モデル構造:** HunyuanImage-2.1はDiTベースのアーキテクチャを持ちます。Text EncoderとしてQwen2.5-VLとbyT5の二つを使用し、専用のVAEを使用します。 +* **必須の引数:** DiTモデル、Qwen2.5-VL、byT5、VAEの各モデルファイルを指定する引数が追加されています。 +* **一部引数の非互換性:** Stable Diffusion向けの引数の一部(例: `--v2`, `--clip_skip`, `--max_token_length`)は使用されません。 +* **HunyuanImage-2.1特有の引数:** Flow Matchingなど、特有の学習パラメータを指定する引数が追加されています。 + +
+ +## 3. Preparation / 準備 + +Before starting training you need: + +1. **Training script:** `hunyuan_image_train_network.py` +2. **HunyuanImage-2.1 DiT model file:** Base DiT model `.safetensors` file. +3. **Text Encoder model files:** + - Qwen2.5-VL model file (`--text_encoder`). + - byT5 model file (`--byt5`). +4. **VAE model file:** HunyuanImage-2.1-compatible VAE model `.safetensors` file (`--vae`). +5. **Dataset definition file (.toml):** TOML format file describing training dataset configuration. + +### Downloading Required Models + +To train HunyuanImage-2.1 models, you need to download the following model files: + +- **DiT Model**: Download from the [Tencent HunyuanImage-2.1](https://huggingface.co/tencent/HunyuanImage-2.1/) repository. Use `dit/hunyuanimage2.1.safetensors`. +- **Text Encoders and VAE**: Download from the [Comfy-Org/HunyuanImage_2.1_ComfyUI](https://huggingface.co/Comfy-Org/HunyuanImage_2.1_ComfyUI) repository: + - Qwen2.5-VL: `split_files/text_encoders/qwen_2.5_vl_7b.safetensors` + - byT5: `split_files/text_encoders/byt5_small_glyphxl_fp16.safetensors` + - VAE: `split_files/vae/hunyuan_image_2.1_vae_fp16.safetensors` + +
+日本語 + +学習を開始する前に、以下のファイルが必要です。 + +1. **学習スクリプト:** `hunyuan_image_train_network.py` +2. **HunyuanImage-2.1 DiTモデルファイル:** 学習のベースとなるDiTモデルの`.safetensors`ファイル。 +3. **Text Encoderモデルファイル:** + - Qwen2.5-VLモデルファイル (`--text_encoder`)。 + - byT5モデルファイル (`--byt5`)。 +4. **VAEモデルファイル:** HunyuanImage-2.1に対応するVAEモデルの`.safetensors`ファイル (`--vae`)。 +5. **データセット定義ファイル (.toml):** 学習データセットの設定を記述したTOML形式のファイル。(詳細は[データセット設定ガイド](config_README-ja.md)を参照してください)。 + +**必要なモデルのダウンロード** + +HunyuanImage-2.1モデルを学習するためには、以下のモデルファイルをダウンロードする必要があります: + +- **DiTモデル**: [Tencent HunyuanImage-2.1](https://huggingface.co/tencent/HunyuanImage-2.1/) リポジトリから `dit/hunyuanimage2.1.safetensors` をダウンロードします。 +- **Text EncoderとVAE**: [Comfy-Org/HunyuanImage_2.1_ComfyUI](https://huggingface.co/Comfy-Org/HunyuanImage_2.1_ComfyUI) リポジトリから以下をダウンロードします: + - Qwen2.5-VL: `split_files/text_encoders/qwen_2.5_vl_7b.safetensors` + - byT5: `split_files/text_encoders/byt5_small_glyphxl_fp16.safetensors` + - VAE: `split_files/vae/hunyuan_image_2.1_vae_fp16.safetensors` + +
+ +## 4. Running the Training / 学習の実行 + +Run `hunyuan_image_train_network.py` from the terminal with HunyuanImage-2.1 specific arguments. Here's a basic command example: + +```bash +accelerate launch --num_cpu_threads_per_process 1 hunyuan_image_train_network.py \ + --pretrained_model_name_or_path="" \ + --text_encoder="" \ + --byt5="" \ + --vae="" \ + --dataset_config="my_hunyuan_dataset_config.toml" \ + --output_dir="" \ + --output_name="my_hunyuan_lora" \ + --save_model_as=safetensors \ + --network_module=networks.lora_hunyuan_image \ + --network_dim=16 \ + --network_alpha=1 \ + --network_train_unet_only \ + --learning_rate=1e-4 \ + --optimizer_type="AdamW8bit" \ + --lr_scheduler="constant" \ + --attn_mode="torch" \ + --split_attn \ + --max_train_epochs=10 \ + --save_every_n_epochs=1 \ + --mixed_precision="bf16" \ + --gradient_checkpointing \ + --model_prediction_type="raw" \ + --discrete_flow_shift=5.0 \ + --blocks_to_swap=18 \ + --cache_text_encoder_outputs \ + --cache_latents +``` + +**HunyuanImage-2.1 training does not support LoRA modules for Text Encoders, so `--network_train_unet_only` is required.** + +
+日本語 + +学習は、ターミナルから`hunyuan_image_train_network.py`を実行することで開始します。基本的なコマンドラインの構造は`train_network.py`と同様ですが、HunyuanImage-2.1特有の引数を指定する必要があります。 + +コマンドラインの例は英語のドキュメントを参照してください。 + +
+ +### 4.1. Explanation of Key Options / 主要なコマンドライン引数の解説 + +The script adds HunyuanImage-2.1 specific arguments. For common arguments (like `--output_dir`, `--output_name`, `--network_module`, etc.), see the [`train_network.py` guide](train_network.md). + +#### Model-related [Required] + +* `--pretrained_model_name_or_path=""` **[Required]** + - Specifies the path to the base DiT model `.safetensors` file. +* `--text_encoder=""` **[Required]** + - Specifies the path to the Qwen2.5-VL Text Encoder model file. Should be `bfloat16`. +* `--byt5=""` **[Required]** + - Specifies the path to the byT5 Text Encoder model file. Should be `float16`. +* `--vae=""` **[Required]** + - Specifies the path to the HunyuanImage-2.1-compatible VAE model `.safetensors` file. + +#### HunyuanImage-2.1 Training Parameters + +* `--network_train_unet_only` **[Required]** + - Specifies that only the DiT model will be trained. LoRA modules for Text Encoders are not supported. +* `--discrete_flow_shift=` + - Specifies the shift value for the scheduler used in Flow Matching. Default is `5.0`. +* `--model_prediction_type=` + - Specifies what the model predicts. Choose from `raw`, `additive`, `sigma_scaled`. Default and recommended is `raw`. +* `--timestep_sampling=` + - Specifies the sampling method for timesteps (noise levels) during training. Choose from `sigma`, `uniform`, `sigmoid`, `shift`, `flux_shift`. Default is `sigma`. +* `--sigmoid_scale=` + - Scale factor when `timestep_sampling` is set to `sigmoid`, `shift`, or `flux_shift`. Default is `1.0`. + +#### Memory/Speed Related + +* `--attn_mode=` + - Specifies the attention implementation to use. Options are `torch`, `xformers`, `flash`, `sageattn`. Default is `torch` (use scaled dot product attention). Each library must be installed separately other than `torch`. If using `xformers`, also specify `--split_attn` if the batch size is more than 1. +* `--split_attn` + - Splits the batch during attention computation to process one item at a time, reducing VRAM usage by avoiding attention mask computation. Can improve speed when using `torch`. Required when using `xformers` with batch size greater than 1. +* `--fp8_scaled` + - Enables training the DiT model in scaled FP8 format. This can significantly reduce VRAM usage (can run with as little as 8GB VRAM when combined with `--blocks_to_swap`), but the training results may vary. This is a newer alternative to the unsupported `--fp8_base` option. See [Musubi Tuner's documentation](https://github.com/kohya-ss/musubi-tuner/blob/main/docs/advanced_config.md#fp8-weight-optimization-for-models--%E3%83%A2%E3%83%87%E3%83%AB%E3%81%AE%E9%87%8D%E3%81%BF%E3%81%AEfp8%E3%81%B8%E3%81%AE%E6%9C%80%E9%81%A9%E5%8C%96) for details. +* `--fp8_vl` + - Use FP8 for the VLM (Qwen2.5-VL) text encoder. +* `--text_encoder_cpu` + - Runs the text encoders on CPU to reduce VRAM usage. This is useful when VRAM is insufficient (less than 12GB). Encoding one text may take a few minutes (depending on CPU). It is highly recommended to use this option with `--cache_text_encoder_outputs_to_disk` to avoid repeated encoding every time training starts. **In addition, increasing `--num_cpu_threads_per_process` in the `accelerate launch` command, like `--num_cpu_threads_per_process=8` or `16`, can speed up encoding in some environments.** +* `--blocks_to_swap=` **[Experimental Feature]** + - Setting to reduce VRAM usage by swapping parts of the model (Transformer blocks) between CPU and GPU. Specify the number of blocks to swap as an integer (e.g., `18`). Larger values reduce VRAM usage but decrease training speed. Adjust according to your GPU's VRAM capacity. Can be used with `gradient_checkpointing`. +* `--cache_text_encoder_outputs` + - Caches the outputs of Qwen2.5-VL and byT5. This reduces memory usage. +* `--cache_latents`, `--cache_latents_to_disk` + - Caches the outputs of VAE. Similar functionality to [sdxl_train_network.py](sdxl_train_network.md). +* `--vae_chunk_size=` + - Enables chunked processing in the VAE to reduce VRAM usage during encoding and decoding. Specify the chunk size as an integer (e.g., `16`). Larger values use more VRAM but are faster. Default is `None` (no chunking). This option is useful when VRAM is limited (e.g., 8GB or 12GB). + +
+日本語 + +[`train_network.py`のガイド](train_network.md)で説明されている引数に加え、以下のHunyuanImage-2.1特有の引数を指定します。共通の引数(`--output_dir`, `--output_name`, `--network_module`, `--network_dim`, `--network_alpha`, `--learning_rate`など)については、上記ガイドを参照してください。 + +コマンドラインの例と詳細な引数の説明は英語のドキュメントを参照してください。 + +
+ +## 5. Using the Trained Model / 学習済みモデルの利用 + +After training, a LoRA model file is saved in `output_dir` and can be used in inference environments supporting HunyuanImage-2.1. + +
+日本語 + +学習が完了すると、指定した`output_dir`にLoRAモデルファイル(例: `my_hunyuan_lora.safetensors`)が保存されます。このファイルは、HunyuanImage-2.1モデルに対応した推論環境で使用できます。 + +
+ +## 6. Advanced Settings / 高度な設定 + +### 6.1. VRAM Usage Optimization / VRAM使用量の最適化 + +HunyuanImage-2.1 is a large model, so GPUs without sufficient VRAM require optimization. + +#### Recommended Settings by GPU Memory + +Based on testing with the pull request, here are recommended VRAM optimization settings: + +| GPU Memory | Recommended Settings | +|------------|---------------------| +| 40GB+ VRAM | Standard settings (no special optimization needed) | +| 24GB VRAM | `--fp8_scaled --blocks_to_swap 9` | +| 12GB VRAM | `--fp8_scaled --blocks_to_swap 32` | +| 8GB VRAM | `--fp8_scaled --blocks_to_swap 37` | + +#### Key VRAM Reduction Options + +- **`--fp8_scaled`**: Enables training the DiT in scaled FP8 format. This is the recommended FP8 option for HunyuanImage-2.1, replacing the unsupported `--fp8_base` option. Essential for <40GB VRAM environments. +- **`--fp8_vl`**: Use FP8 for the VLM (Qwen2.5-VL) text encoder. +- **`--blocks_to_swap `**: Swaps blocks between CPU and GPU to reduce VRAM usage. Higher numbers save more VRAM but reduce training speed. Up to 37 blocks can be swapped for HunyuanImage-2.1. +- **`--cpu_offload_checkpointing`**: Offloads gradient checkpoints to CPU. Can reduce VRAM usage but decreases training speed. Cannot be used with `--blocks_to_swap`. +- **Using Adafactor optimizer**: Can reduce VRAM usage more than 8bit AdamW: + ``` + --optimizer_type adafactor --optimizer_args "relative_step=False" "scale_parameter=False" "warmup_init=False" --lr_scheduler constant_with_warmup --max_grad_norm 0.0 + ``` + +
+日本語 + +HunyuanImage-2.1は大きなモデルであるため、十分なVRAMを持たないGPUでは工夫が必要です。 + +#### GPU別推奨設定 + +Pull Requestのテスト結果に基づく推奨VRAM最適化設定: + +| GPU Memory | 推奨設定 | +|------------|---------| +| 40GB+ VRAM | 標準設定(特別な最適化不要) | +| 24GB VRAM | `--fp8_scaled --blocks_to_swap 9` | +| 12GB VRAM | `--fp8_scaled --blocks_to_swap 32` | +| 8GB VRAM | `--fp8_scaled --blocks_to_swap 37` | + +主要なVRAM削減オプション: +- `--fp8_scaled`: DiTをスケールされたFP8形式で学習(推奨されるFP8オプション、40GB VRAM未満の環境では必須) +- `--fp8_vl`: VLMテキストエンコーダにFP8を使用 +- `--blocks_to_swap`: CPUとGPU間でブロックをスワップ(最大37ブロック) +- `--cpu_offload_checkpointing`: 勾配チェックポイントをCPUにオフロード +- Adafactorオプティマイザの使用 + +
+ +### 6.2. Important HunyuanImage-2.1 LoRA Training Settings / HunyuanImage-2.1 LoRA学習の重要な設定 + +HunyuanImage-2.1 training has several settings that can be specified with arguments: + +#### Timestep Sampling Methods + +The `--timestep_sampling` option specifies how timesteps (0-1) are sampled: + +- `sigma`: Sigma-based like SD3 (Default) +- `uniform`: Uniform random +- `sigmoid`: Sigmoid of normal distribution random +- `shift`: Sigmoid value of normal distribution random with shift. +- `flux_shift`: Shift sigmoid value of normal distribution random according to resolution. + +#### Model Prediction Processing + +The `--model_prediction_type` option specifies how to interpret and process model predictions: + +- `raw`: Use as-is **[Recommended, Default]** +- `additive`: Add to noise input +- `sigma_scaled`: Apply sigma scaling + +#### Recommended Settings + +Based on experiments, the default settings work well: +``` +--model_prediction_type raw --discrete_flow_shift 5.0 +``` + +
+日本語 + +HunyuanImage-2.1の学習には、引数で指定できるいくつかの設定があります。詳細な説明とコマンドラインの例は英語のドキュメントを参照してください。 + +主要な設定オプション: +- タイムステップのサンプリング方法(`--timestep_sampling`) +- モデル予測の処理方法(`--model_prediction_type`) +- 推奨設定の組み合わせ + +
+ +### 6.3. Regular Expression-based Rank/LR Configuration / 正規表現によるランク・学習率の指定 + +You can specify ranks (dims) and learning rates for LoRA modules using regular expressions. This allows for more flexible and fine-grained control. + +These settings are specified via the `network_args` argument. + +* `network_reg_dims`: Specify ranks for modules matching a regular expression. The format is a comma-separated string of `pattern=rank`. + * Example: `--network_args "network_reg_dims=attn.*.q_proj=4,attn.*.k_proj=4"` +* `network_reg_lrs`: Specify learning rates for modules matching a regular expression. The format is a comma-separated string of `pattern=lr`. + * Example: `--network_args "network_reg_lrs=down_blocks.1=1e-4,up_blocks.2=2e-4"` + +**Notes:** + +* To find the correct module names for the patterns, you may need to inspect the model structure. +* Settings via `network_reg_dims` and `network_reg_lrs` take precedence over the global `--network_dim` and `--learning_rate` settings. +* If a module name matches multiple patterns, the setting from the last matching pattern in the string will be applied. + +
+日本語 + +正規表現を用いて、LoRAのモジュールごとにランク(dim)や学習率を指定することができます。これにより、柔軟できめ細やかな制御が可能になります。 + +これらの設定は `network_args` 引数で指定します。 + +* `network_reg_dims`: 正規表現にマッチするモジュールに対してランクを指定します。 +* `network_reg_lrs`: 正規表現にマッチするモジュールに対して学習率を指定します。 + +**注意点:** + +* パターンのための正確なモジュール名を見つけるには、モデルの構造を調べる必要があるかもしれません。 +* `network_reg_dims` および `network_reg_lrs` での設定は、全体設定である `--network_dim` や `--learning_rate` よりも優先されます。 +* あるモジュール名が複数のパターンにマッチした場合、文字列の中で後方にあるパターンの設定が適用されます。 + +
+ +### 6.4. Multi-Resolution Training / マルチ解像度トレーニング + +You can define multiple resolutions in the dataset configuration file, with different batch sizes for each resolution. + +**Note:** This feature is available, but it is **not recommended** as the HunyuanImage-2.1 base model was not trained with multi-resolution capabilities. Using it may lead to unexpected results. + +Configuration file example: +```toml +[general] +shuffle_caption = true +caption_extension = ".txt" + +[[datasets]] +batch_size = 2 +enable_bucket = true +resolution = [1024, 1024] + + [[datasets.subsets]] + image_dir = "path/to/image/directory" + num_repeats = 1 + +[[datasets]] +batch_size = 1 +enable_bucket = true +resolution = [1280, 768] + + [[datasets.subsets]] + image_dir = "path/to/another/directory" + num_repeats = 1 +``` + +
+日本語 + +データセット設定ファイルで複数の解像度を定義できます。各解像度に対して異なるバッチサイズを指定することができます。 + +**注意:** この機能は利用可能ですが、HunyuanImage-2.1のベースモデルはマルチ解像度で学習されていないため、**非推奨**です。使用すると予期しない結果になる可能性があります。 + +設定ファイルの例は英語のドキュメントを参照してください。 + +
+ +### 6.5. Validation / 検証 + +You can calculate validation loss during training using a validation dataset to evaluate model generalization performance. This feature works the same as in other training scripts. For details, please refer to the [Validation Guide](validation.md). + +
+日本語 + +学習中に検証データセットを使用して損失 (Validation Loss) を計算し、モデルの汎化性能を評価できます。この機能は他の学習スクリプトと同様に動作します。詳細は[検証ガイド](validation.md)を参照してください。 + +
+ +## 7. Other Training Options / その他の学習オプション + +- **`--ip_noise_gamma`**: Use `--ip_noise_gamma` and `--ip_noise_gamma_random_strength` to adjust Input Perturbation noise gamma values during training. See Stable Diffusion 3 training options for details. + +- **`--loss_type`**: Specifies the loss function for training. The default is `l2`. + - `l1`: L1 loss. + - `l2`: L2 loss (mean squared error). + - `huber`: Huber loss. + - `smooth_l1`: Smooth L1 loss. + +- **`--huber_schedule`**, **`--huber_c`**, **`--huber_scale`**: These are parameters for Huber loss. They are used when `--loss_type` is `huber` or `smooth_l1`. + +- **`--weighting_scheme`**, **`--logit_mean`**, **`--logit_std`**, **`--mode_scale`**: These options allow you to adjust the loss weighting for each timestep. For details, refer to the [`sd3_train_network.md` guide](sd3_train_network.md). + +- **`--fused_backward_pass`**: Fuses the backward pass and optimizer step to reduce VRAM usage. + +
+日本語 + +- **`--ip_noise_gamma`**: Input Perturbationノイズのガンマ値を調整します。 +- **`--loss_type`**: 学習に用いる損失関数を指定します。 +- **`--huber_schedule`**, **`--huber_c`**, **`--huber_scale`**: Huber損失のパラメータです。 +- **`--weighting_scheme`**, **`--logit_mean`**, **`--logit_std`**, **`--mode_scale`**: 各タイムステップの損失の重み付けを調整します。 +- **`--fused_backward_pass`**: バックワードパスとオプティマイザステップを融合してVRAM使用量を削減します。 + +
+ +## 8. Using the Inference Script / 推論スクリプトの使用法 + +The `hunyuan_image_minimal_inference.py` script allows you to generate images using trained LoRA models. Here's a basic usage example: + +```bash +python hunyuan_image_minimal_inference.py \ + --dit "" \ + --text_encoder "" \ + --byt5 "" \ + --vae "" \ + --lora_weight "" \ + --lora_multiplier 1.0 \ + --attn_mode "torch" \ + --prompt "A cute cartoon penguin in a snowy landscape" \ + --image_size 2048 2048 \ + --infer_steps 50 \ + --guidance_scale 3.5 \ + --flow_shift 5.0 \ + --seed 542017 \ + --save_path "output_image.png" +``` + +**Key Options:** +- `--fp8_scaled`: Use scaled FP8 format for reduced VRAM usage during inference +- `--blocks_to_swap`: Swap blocks to CPU to reduce VRAM usage +- `--image_size`: Resolution in **height width** (inference is most stable at 2560x1536, 2304x1792, 2048x2048, 1792x2304, 1536x2560 according to the official repo) +- `--guidance_scale`: CFG scale (default: 3.5) +- `--flow_shift`: Flow matching shift parameter (default: 5.0) +- `--text_encoder_cpu`: Run the text encoders on CPU to reduce VRAM usage +- `--vae_chunk_size`: Chunk size for VAE decoding to reduce memory usage (default: None, no chunking). 16 is recommended if enabled. +- `--apg_start_step_general` and `--apg_start_step_ocr`: Start steps for APG (Adaptive Projected Guidance) if using APG during inference. `5` and `38` are the official recommended values for 50 steps. If this value exceeds `--infer_steps`, APG will not be applied. +- `--guidance_rescale`: Rescales the guidance for steps before APG starts. Default is `0.0` (no rescaling). If you use this option, a value around `0.5` might be good starting point. +- `--guidance_rescale_apg`: Rescales the guidance for APG. Default is `0.0` (no rescaling). This option doesn't seem to have a large effect, but if you use it, a value around `0.5` might be a good starting point. + +`--split_attn` is not supported (since inference is done one at a time). `--fp8_vl` is not supported, please use CPU for the text encoder if VRAM is insufficient. + +
+日本語 + +`hunyuan_image_minimal_inference.py`スクリプトを使用して、学習したLoRAモデルで画像を生成できます。基本的な使用例は英語のドキュメントを参照してください。 + +**主要なオプション:** +- `--fp8_scaled`: VRAM使用量削減のためのスケールFP8形式 +- `--blocks_to_swap`: VRAM使用量削減のためのブロックスワップ +- `--image_size`: 解像度(2048x2048で最も安定) +- `--guidance_scale`: CFGスケール(推奨: 3.5) +- `--flow_shift`: Flow Matchingシフトパラメータ(デフォルト: 5.0) +- `--text_encoder_cpu`: テキストエンコーダをCPUで実行してVRAM使用量削減 +- `--vae_chunk_size`: VAEデコーディングのチャンクサイズ(デフォルト: None、チャンク処理なし)。有効にする場合は16を推奨。 +- `--apg_start_step_general` と `--apg_start_step_ocr`: 推論中にAPGを使用する場合の開始ステップ。50ステップの場合、公式推奨値はそれぞれ5と38です。この値が`--infer_steps`を超えると、APGは適用されません。 +- `--guidance_rescale`: APG開始前のステップに対するガイダンスのリスケーリング。デフォルトは0.0(リスケーリングなし)。使用する場合、0.5程度から始めて調整してください。 +- `--guidance_rescale_apg`: APGに対するガイダンスのリスケーリング。デフォルトは0.0(リスケーリングなし)。このオプションは大きな効果はないようですが、使用する場合は0.5程度から始めて調整してください。 + +`--split_attn`はサポートされていません(1件ずつ推論するため)。`--fp8_vl`もサポートされていません。VRAMが不足する場合はテキストエンコーダをCPUで実行してください。 + +
+ +## 9. Related Tools / 関連ツール + +### `networks/convert_hunyuan_image_lora_to_comfy.py` + +A script to convert LoRA models to ComfyUI-compatible format. The formats differ slightly, so conversion is necessary. You can convert from the sd-scripts format to ComfyUI format with: + +```bash +python networks/convert_hunyuan_image_lora_to_comfy.py path/to/source.safetensors path/to/destination.safetensors +``` + +Using the `--reverse` option allows conversion in the opposite direction (ComfyUI format to sd-scripts format). However, reverse conversion is only possible for LoRAs converted by this script. LoRAs created with other training tools cannot be converted. + +
+日本語 + +**`networks/convert_hunyuan_image_lora_to_comfy.py`** + +LoRAモデルをComfyUI互換形式に変換するスクリプト。わずかに形式が異なるため、変換が必要です。以下の指定で、sd-scriptsの形式からComfyUI形式に変換できます。 + +```bash +python networks/convert_hunyuan_image_lora_to_comfy.py path/to/source.safetensors path/to/destination.safetensors +``` + +`--reverse`オプションを付けると、逆変換(ComfyUI形式からsd-scripts形式)も可能です。ただし、逆変換ができるのはこのスクリプトで変換したLoRAに限ります。他の学習ツールで作成したLoRAは変換できません。 + +
+ +## 10. Others / その他 + +`hunyuan_image_train_network.py` includes many features common with `train_network.py`, such as sample image generation (`--sample_prompts`, etc.) and detailed optimizer settings. For these features, refer to the [`train_network.py` guide](train_network.md#5-other-features--その他の機能) or the script help (`python hunyuan_image_train_network.py --help`). + +
+日本語 + +`hunyuan_image_train_network.py`には、サンプル画像の生成 (`--sample_prompts`など) や詳細なオプティマイザ設定など、`train_network.py`と共通の機能も多く存在します。これらについては、[`train_network.py`のガイド](train_network.md#5-other-features--その他の機能)やスクリプトのヘルプ (`python hunyuan_image_train_network.py --help`) を参照してください。 + +
diff --git a/docs/lumina_train_network.md b/docs/lumina_train_network.md new file mode 100644 index 000000000..80ae84187 --- /dev/null +++ b/docs/lumina_train_network.md @@ -0,0 +1,319 @@ +# LoRA Training Guide for Lumina Image 2.0 using `lumina_train_network.py` / `lumina_train_network.py` を用いたLumina Image 2.0モデルのLoRA学習ガイド + +This document explains how to train LoRA (Low-Rank Adaptation) models for Lumina Image 2.0 using `lumina_train_network.py` in the `sd-scripts` repository. + +## 1. Introduction / はじめに + +`lumina_train_network.py` trains additional networks such as LoRA for Lumina Image 2.0 models. Lumina Image 2.0 adopts a Next-DiT (Next-generation Diffusion Transformer) architecture, which differs from previous Stable Diffusion models. It uses a single text encoder (Gemma2) and a dedicated AutoEncoder (AE). + +This guide assumes you already understand the basics of LoRA training. For common usage and options, see [the train_network.py guide](./train_network.md). Some parameters are similar to those in [`sd3_train_network.py`](sd3_train_network.md) and [`flux_train_network.py`](flux_train_network.md). + +**Prerequisites:** + +* The `sd-scripts` repository has been cloned and the Python environment is ready. +* A training dataset has been prepared. See the [Dataset Configuration Guide](./config_README-en.md). +* Lumina Image 2.0 model files for training are available. + +
+日本語 + +`lumina_train_network.py`は、Lumina Image 2.0モデルに対してLoRAなどの追加ネットワークを学習させるためのスクリプトです。Lumina Image 2.0は、Next-DiT (Next-generation Diffusion Transformer) と呼ばれる新しいアーキテクチャを採用しており、従来のStable Diffusionモデルとは構造が異なります。テキストエンコーダーとしてGemma2を単体で使用し、専用のAutoEncoder (AE) を使用します。 + +このガイドは、基本的なLoRA学習の手順を理解しているユーザーを対象としています。基本的な使い方や共通のオプションについては、`train_network.py`のガイド(作成中)を参照してください。また一部のパラメータは [`sd3_train_network.py`](sd3_train_network.md) や [`flux_train_network.py`](flux_train_network.md) と同様のものがあるため、そちらも参考にしてください。 + +**前提条件:** + +* `sd-scripts`リポジトリのクローンとPython環境のセットアップが完了していること。 +* 学習用データセットの準備が完了していること。(データセットの準備については[データセット設定ガイド](./config_README-en.md)を参照してください) +* 学習対象のLumina Image 2.0モデルファイルが準備できていること。 +
+ +## 2. Differences from `train_network.py` / `train_network.py` との違い + +`lumina_train_network.py` is based on `train_network.py` but modified for Lumina Image 2.0. Main differences are: + +* **Target models:** Lumina Image 2.0 models. +* **Model structure:** Uses Next-DiT (Transformer based) instead of U-Net and employs a single text encoder (Gemma2). The AutoEncoder (AE) is not compatible with SDXL/SD3/FLUX. +* **Arguments:** Options exist to specify the Lumina Image 2.0 model, Gemma2 text encoder and AE. With a single `.safetensors` file, these components are typically provided separately. +* **Incompatible arguments:** Stable Diffusion v1/v2 options such as `--v2`, `--v_parameterization` and `--clip_skip` are not used. +* **Lumina specific options:** Additional parameters for timestep sampling, model prediction type, discrete flow shift, and system prompt. + +
+日本語 +`lumina_train_network.py`は`train_network.py`をベースに、Lumina Image 2.0モデルに対応するための変更が加えられています。主な違いは以下の通りです。 + +* **対象モデル:** Lumina Image 2.0モデルを対象とします。 +* **モデル構造:** U-Netの代わりにNext-DiT (Transformerベース) を使用します。Text EncoderとしてGemma2を単体で使用し、専用のAutoEncoder (AE) を使用します。 +* **引数:** Lumina Image 2.0モデル、Gemma2 Text Encoder、AEを指定する引数があります。通常、これらのコンポーネントは個別に提供されます。 +* **一部引数の非互換性:** Stable Diffusion v1/v2向けの引数(例: `--v2`, `--v_parameterization`, `--clip_skip`)はLumina Image 2.0の学習では使用されません。 +* **Lumina特有の引数:** タイムステップのサンプリング、モデル予測タイプ、離散フローシフト、システムプロンプトに関する引数が追加されています。 +
+ +## 3. Preparation / 準備 + +The following files are required before starting training: + +1. **Training script:** `lumina_train_network.py` +2. **Lumina Image 2.0 model file:** `.safetensors` file for the base model. +3. **Gemma2 text encoder file:** `.safetensors` file for the text encoder. +4. **AutoEncoder (AE) file:** `.safetensors` file for the AE. +5. **Dataset definition file (.toml):** Dataset settings in TOML format. (See the [Dataset Configuration Guide](./config_README-en.md). In this document we use `my_lumina_dataset_config.toml` as an example. + + +**Model Files:** +* Lumina Image 2.0: `lumina-image-2.safetensors` ([full precision link](https://huggingface.co/rockerBOO/lumina-image-2/blob/main/lumina-image-2.safetensors)) or `lumina_2_model_bf16.safetensors` ([bf16 link](https://huggingface.co/Comfy-Org/Lumina_Image_2.0_Repackaged/blob/main/split_files/diffusion_models/lumina_2_model_bf16.safetensors)) +* Gemma2 2B (fp16): `gemma-2-2b.safetensors` ([link](https://huggingface.co/Comfy-Org/Lumina_Image_2.0_Repackaged/blob/main/split_files/text_encoders/gemma_2_2b_fp16.safetensors)) +* AutoEncoder: `ae.safetensors` ([link](https://huggingface.co/Comfy-Org/Lumina_Image_2.0_Repackaged/blob/main/split_files/vae/ae.safetensors)) (same as FLUX) + + +
+日本語 +学習を開始する前に、以下のファイルが必要です。 + +1. **学習スクリプト:** `lumina_train_network.py` +2. **Lumina Image 2.0モデルファイル:** 学習のベースとなるLumina Image 2.0モデルの`.safetensors`ファイル。 +3. **Gemma2テキストエンコーダーファイル:** Gemma2テキストエンコーダーの`.safetensors`ファイル。 +4. **AutoEncoder (AE) ファイル:** AEの`.safetensors`ファイル。 +5. **データセット定義ファイル (.toml):** 学習データセットの設定を記述したTOML形式のファイル。(詳細は[データセット設定ガイド](./config_README-en.md)を参照してください)。 + * 例として`my_lumina_dataset_config.toml`を使用します。 + +**モデルファイル** は英語ドキュメントの通りです。 + +
+ +## 4. Running the Training / 学習の実行 + +Execute `lumina_train_network.py` from the terminal to start training. The overall command-line format is the same as `train_network.py`, but Lumina Image 2.0 specific options must be supplied. + +Example command: + +```bash +accelerate launch --num_cpu_threads_per_process 1 lumina_train_network.py \ + --pretrained_model_name_or_path="lumina-image-2.safetensors" \ + --gemma2="gemma-2-2b.safetensors" \ + --ae="ae.safetensors" \ + --dataset_config="my_lumina_dataset_config.toml" \ + --output_dir="./output" \ + --output_name="my_lumina_lora" \ + --save_model_as=safetensors \ + --network_module=networks.lora_lumina \ + --network_dim=8 \ + --network_alpha=8 \ + --learning_rate=1e-4 \ + --optimizer_type="AdamW" \ + --lr_scheduler="constant" \ + --timestep_sampling="nextdit_shift" \ + --discrete_flow_shift=6.0 \ + --model_prediction_type="raw" \ + --system_prompt="You are an assistant designed to generate high-quality images based on user prompts." \ + --max_train_epochs=10 \ + --save_every_n_epochs=1 \ + --mixed_precision="bf16" \ + --gradient_checkpointing \ + --cache_latents \ + --cache_text_encoder_outputs +``` + +*(Write the command on one line or use `\` or `^` for line breaks.)* + +
+日本語 +学習は、ターミナルから`lumina_train_network.py`を実行することで開始します。基本的なコマンドラインの構造は`train_network.py`と同様ですが、Lumina Image 2.0特有の引数を指定する必要があります。 + +以下に、基本的なコマンドライン実行例を示します。 + +```bash +accelerate launch --num_cpu_threads_per_process 1 lumina_train_network.py \ + --pretrained_model_name_or_path="lumina-image-2.safetensors" \ + --gemma2="gemma-2-2b.safetensors" \ + --ae="ae.safetensors" \ + --dataset_config="my_lumina_dataset_config.toml" \ + --output_dir="./output" \ + --output_name="my_lumina_lora" \ + --save_model_as=safetensors \ + --network_module=networks.lora_lumina \ + --network_dim=8 \ + --network_alpha=8 \ + --learning_rate=1e-4 \ + --optimizer_type="AdamW" \ + --lr_scheduler="constant" \ + --timestep_sampling="nextdit_shift" \ + --discrete_flow_shift=6.0 \ + --model_prediction_type="raw" \ + --system_prompt="You are an assistant designed to generate high-quality images based on user prompts." \ + --max_train_epochs=10 \ + --save_every_n_epochs=1 \ + --mixed_precision="bf16" \ + --gradient_checkpointing \ + --cache_latents \ + --cache_text_encoder_outputs +``` + +※実際には1行で書くか、適切な改行文字(`\` または `^`)を使用してください。 +
+ +### 4.1. Explanation of Key Options / 主要なコマンドライン引数の解説 + +Besides the arguments explained in the [train_network.py guide](train_network.md), specify the following Lumina Image 2.0 options. For shared options (`--output_dir`, `--output_name`, etc.), see that guide. + +#### Model Options / モデル関連 + +* `--pretrained_model_name_or_path=""` **required** – Path to the Lumina Image 2.0 model. +* `--gemma2=""` **required** – Path to the Gemma2 text encoder `.safetensors` file. +* `--ae=""` **required** – Path to the AutoEncoder `.safetensors` file. + +#### Lumina Image 2.0 Training Parameters / Lumina Image 2.0 学習パラメータ + +* `--gemma2_max_token_length=` – Max token length for Gemma2. Default is 256. +* `--timestep_sampling=` – Timestep sampling method. Options: `sigma`, `uniform`, `sigmoid`, `shift`, `nextdit_shift`. Default `shift`. **Recommended: `nextdit_shift`** +* `--discrete_flow_shift=` – Discrete flow shift for the Euler Discrete Scheduler. Default `6.0`. +* `--model_prediction_type=` – Model prediction processing method. Options: `raw`, `additive`, `sigma_scaled`. Default `raw`. **Recommended: `raw`** +* `--system_prompt=` – System prompt to prepend to all prompts. Recommended: `"You are an assistant designed to generate high-quality images based on user prompts."` or `"You are an assistant designed to generate high-quality images with the highest degree of image-text alignment based on textual prompts."` +* `--use_flash_attn` – Use Flash Attention. Requires `pip install flash-attn` (may not be supported in all environments). If installed correctly, it speeds up training. +* `--use_sage_attn` – Use Sage Attention for the model. +* `--sample_batch_size=` – Batch size to use for sampling, defaults to `--training_batch_size` value. Sample batches are bucketed by width, height, guidance scale, and seed. +* `--sigmoid_scale=` – Scale factor for sigmoid timestep sampling. Default `1.0`. + +#### Memory and Speed / メモリ・速度関連 + +* `--blocks_to_swap=` **[experimental]** – Swap a number of Transformer blocks between CPU and GPU. More blocks reduce VRAM but slow training. Cannot be used with `--cpu_offload_checkpointing`. +* `--cache_text_encoder_outputs` – Cache Gemma2 outputs to reduce memory usage. +* `--cache_latents`, `--cache_latents_to_disk` – Cache AE outputs. +* `--fp8_base` – Use FP8 precision for the base model. + +#### Network Arguments / ネットワーク引数 + +For Lumina Image 2.0, you can specify different dimensions for various components: + +* `--network_args` can include: + * `"attn_dim=4"` – Attention dimension + * `"mlp_dim=4"` – MLP dimension + * `"mod_dim=4"` – Modulation dimension + * `"refiner_dim=4"` – Refiner blocks dimension + * `"embedder_dims=[4,4,4]"` – Embedder dimensions for x, t, and caption embedders + +#### Incompatible or Deprecated Options / 非互換・非推奨の引数 + +* `--v2`, `--v_parameterization`, `--clip_skip` – Options for Stable Diffusion v1/v2 that are not used for Lumina Image 2.0. + +
+日本語 + +[`train_network.py`のガイド](train_network.md)で説明されている引数に加え、以下のLumina Image 2.0特有の引数を指定します。共通の引数については、上記ガイドを参照してください。 + +#### モデル関連 + +* `--pretrained_model_name_or_path=""` **[必須]** + * 学習のベースとなるLumina Image 2.0モデルの`.safetensors`ファイルのパスを指定します。 +* `--gemma2=""` **[必須]** + * Gemma2テキストエンコーダーの`.safetensors`ファイルのパスを指定します。 +* `--ae=""` **[必須]** + * AutoEncoderの`.safetensors`ファイルのパスを指定します。 + +#### Lumina Image 2.0 学習パラメータ + +* `--gemma2_max_token_length=` – Gemma2で使用するトークンの最大長を指定します。デフォルトは256です。 +* `--timestep_sampling=` – タイムステップのサンプリング方法を指定します。`sigma`, `uniform`, `sigmoid`, `shift`, `nextdit_shift`から選択します。デフォルトは`shift`です。**推奨: `nextdit_shift`** +* `--discrete_flow_shift=` – Euler Discrete Schedulerの離散フローシフトを指定します。デフォルトは`6.0`です。 +* `--model_prediction_type=` – モデル予測の処理方法を指定します。`raw`, `additive`, `sigma_scaled`から選択します。デフォルトは`raw`です。**推奨: `raw`** +* `--system_prompt=` – 全てのプロンプトに前置するシステムプロンプトを指定します。推奨: `"You are an assistant designed to generate high-quality images based on user prompts."` または `"You are an assistant designed to generate high-quality images with the highest degree of image-text alignment based on textual prompts."` +* `--use_flash_attn` – Flash Attentionを使用します。`pip install flash-attn`でインストールが必要です(環境によってはサポートされていません)。正しくインストールされている場合は、指定すると学習が高速化されます。 +* `--use_sage_attn` – Sage Attentionを使用します。 +* `--sample_batch_size=` – サンプリングに使用するバッチサイズ。デフォルトは `--training_batch_size` の値です。サンプルバッチは、幅、高さ、ガイダンススケール、シードによってバケット化されます。 +* `--sigmoid_scale=` – sigmoidタイムステップサンプリングのスケール係数を指定します。デフォルトは`1.0`です。 + +#### メモリ・速度関連 + +* `--blocks_to_swap=` **[実験的機能]** – TransformerブロックをCPUとGPUでスワップしてVRAMを節約します。`--cpu_offload_checkpointing`とは併用できません。 +* `--cache_text_encoder_outputs` – Gemma2の出力をキャッシュしてメモリ使用量を削減します。 +* `--cache_latents`, `--cache_latents_to_disk` – AEの出力をキャッシュします。 +* `--fp8_base` – ベースモデルにFP8精度を使用します。 + +#### ネットワーク引数 + +Lumina Image 2.0では、各コンポーネントに対して異なる次元を指定できます: + +* `--network_args` には以下を含めることができます: + * `"attn_dim=4"` – アテンション次元 + * `"mlp_dim=4"` – MLP次元 + * `"mod_dim=4"` – モジュレーション次元 + * `"refiner_dim=4"` – リファイナーブロック次元 + * `"embedder_dims=[4,4,4]"` – x、t、キャプションエンベッダーのエンベッダー次元 + +#### 非互換・非推奨の引数 + +* `--v2`, `--v_parameterization`, `--clip_skip` – Stable Diffusion v1/v2向けの引数のため、Lumina Image 2.0学習では使用されません。 +
+ +### 4.2. Starting Training / 学習の開始 + +After setting the required arguments, run the command to begin training. The overall flow and how to check logs are the same as in the [train_network.py guide](train_network.md#32-starting-the-training--学習の開始). + +## 5. Using the Trained Model / 学習済みモデルの利用 + +When training finishes, a LoRA model file (e.g. `my_lumina_lora.safetensors`) is saved in the directory specified by `output_dir`. Use this file with inference environments that support Lumina Image 2.0, such as ComfyUI with appropriate nodes. + +### Inference with scripts in this repository / このリポジトリのスクリプトを使用した推論 + +The inference script is also available. The script is `lumina_minimal_inference.py`. See `--help` for options. + +``` +python lumina_minimal_inference.py --pretrained_model_name_or_path path/to/lumina.safetensors --gemma2_path path/to/gemma.safetensors" --ae_path path/to/flux_ae.safetensors --output_dir path/to/output_dir --offload --seed 1234 --prompt "Positive prompt" --system_prompt "You are an assistant designed to generate high-quality images based on user prompts." --negative_prompt "negative prompt" +``` + +`--add_system_prompt_to_negative_prompt` option can be used to add the system prompt to the negative prompt. + +`--lora_weights` option can be used to specify the LoRA weights file, and optional multiplier (like `path;1.0`). + +## 6. Others / その他 + +`lumina_train_network.py` shares many features with `train_network.py`, such as sample image generation (`--sample_prompts`, etc.) and detailed optimizer settings. For these, see the [train_network.py guide](train_network.md#5-other-features--その他の機能) or run `python lumina_train_network.py --help`. + +### 6.1. Recommended Settings / 推奨設定 + +Based on the contributor's recommendations, here are the suggested settings for optimal training: + +**Key Parameters:** +* `--timestep_sampling="nextdit_shift"` +* `--discrete_flow_shift=6.0` +* `--model_prediction_type="raw"` +* `--mixed_precision="bf16"` + +**System Prompts:** +* General purpose: `"You are an assistant designed to generate high-quality images based on user prompts."` +* High image-text alignment: `"You are an assistant designed to generate high-quality images with the highest degree of image-text alignment based on textual prompts."` + +**Sample Prompts:** +Sample prompts can include CFG truncate (`--ctr`) and Renorm CFG (`-rcfg`) parameters: +* `--ctr 0.25 --rcfg 1.0` (default values) + +
+日本語 + +必要な引数を設定し、コマンドを実行すると学習が開始されます。基本的な流れやログの確認方法は[`train_network.py`のガイド](train_network.md#32-starting-the-training--学習の開始)と同様です。 + +学習が完了すると、指定した`output_dir`にLoRAモデルファイル(例: `my_lumina_lora.safetensors`)が保存されます。このファイルは、Lumina Image 2.0モデルに対応した推論環境(例: ComfyUI + 適切なノード)で使用できます。 + +当リポジトリ内の推論スクリプトを用いて推論することも可能です。スクリプトは`lumina_minimal_inference.py`です。オプションは`--help`で確認できます。記述例は英語版のドキュメントをご確認ください。 + +`lumina_train_network.py`には、サンプル画像の生成 (`--sample_prompts`など) や詳細なオプティマイザ設定など、`train_network.py`と共通の機能も多く存在します。これらについては、[`train_network.py`のガイド](train_network.md#5-other-features--その他の機能)やスクリプトのヘルプ (`python lumina_train_network.py --help`) を参照してください。 + +### 6.1. 推奨設定 + +コントリビューターの推奨に基づく、最適な学習のための推奨設定: + +**主要パラメータ:** +* `--timestep_sampling="nextdit_shift"` +* `--discrete_flow_shift=6.0` +* `--model_prediction_type="raw"` +* `--mixed_precision="bf16"` + +**システムプロンプト:** +* 汎用目的: `"You are an assistant designed to generate high-quality images based on user prompts."` +* 高い画像-テキスト整合性: `"You are an assistant designed to generate high-quality images with the highest degree of image-text alignment based on textual prompts."` + +**サンプルプロンプト:** +サンプルプロンプトには CFG truncate (`--ctr`) と Renorm CFG (`--rcfg`) パラメータを含めることができます: +* `--ctr 0.25 --rcfg 1.0` (デフォルト値) + +
\ No newline at end of file diff --git a/docs/sd3_train_network.md b/docs/sd3_train_network.md new file mode 100644 index 000000000..30876ce05 --- /dev/null +++ b/docs/sd3_train_network.md @@ -0,0 +1,355 @@ +# LoRA Training Guide for Stable Diffusion 3/3.5 using `sd3_train_network.py` / `sd3_train_network.py` を用いたStable Diffusion 3/3.5モデルのLoRA学習ガイド + +This document explains how to train LoRA (Low-Rank Adaptation) models for Stable Diffusion 3 (SD3) and Stable Diffusion 3.5 (SD3.5) using `sd3_train_network.py` in the `sd-scripts` repository. + +## 1. Introduction / はじめに + +`sd3_train_network.py` trains additional networks such as LoRA for SD3/3.5 models. SD3 adopts a new architecture called MMDiT (Multi-Modal Diffusion Transformer), so its structure differs from previous Stable Diffusion models. With this script you can create LoRA models specialized for SD3/3.5. + +This guide assumes you already understand the basics of LoRA training. For common usage and options, see the [train_network.py guide](train_network.md). Some parameters are the same as those in [`sdxl_train_network.py`](sdxl_train_network.md). + +**Prerequisites:** + +* The `sd-scripts` repository has been cloned and the Python environment is ready. +* A training dataset has been prepared. See the [Dataset Configuration Guide](link/to/dataset/config/doc). +* SD3/3.5 model files for training are available. + +
+日本語 + +`sd3_train_network.py`は、Stable Diffusion 3/3.5モデルに対してLoRAなどの追加ネットワークを学習させるためのスクリプトです。SD3は、MMDiT (Multi-Modal Diffusion Transformer) と呼ばれる新しいアーキテクチャを採用しており、従来のStable Diffusionモデルとは構造が異なります。このスクリプトを使用することで、SD3/3.5モデルに特化したLoRAモデルを作成できます。 + +このガイドは、基本的なLoRA学習の手順を理解しているユーザーを対象としています。基本的な使い方や共通のオプションについては、[`train_network.py`のガイド](train_network.md)を参照してください。また一部のパラメータは [`sdxl_train_network.py`](sdxl_train_network.md) と同様のものがあるため、そちらも参考にしてください。 + +**前提条件:** + +* `sd-scripts`リポジトリのクローンとPython環境のセットアップが完了していること。 +* 学習用データセットの準備が完了していること。(データセットの準備については[データセット設定ガイド](link/to/dataset/config/doc)を参照してください) +* 学習対象のSD3/3.5モデルファイルが準備できていること。 +
+ +## 2. Differences from `train_network.py` / `train_network.py` との違い + +`sd3_train_network.py` is based on `train_network.py` but modified for SD3/3.5. Main differences are: + +* **Target models:** Stable Diffusion 3 and 3.5 Medium/Large. +* **Model structure:** Uses MMDiT (Transformer based) instead of U-Net and employs three text encoders: CLIP-L, CLIP-G and T5-XXL. The VAE is not compatible with SDXL. +* **Arguments:** Options exist to specify the SD3/3.5 model, text encoders and VAE. With a single `.safetensors` file, these paths are detected automatically, so separate paths are optional. +* **Incompatible arguments:** Stable Diffusion v1/v2 options such as `--v2`, `--v_parameterization` and `--clip_skip` are not used. +* **SD3 specific options:** Additional parameters for attention masks, dropout rates, positional embedding adjustments (for SD3.5), timestep sampling and loss weighting. + +
+日本語 +`sd3_train_network.py`は`train_network.py`をベースに、SD3/3.5モデルに対応するための変更が加えられています。主な違いは以下の通りです。 + +* **対象モデル:** Stable Diffusion 3, 3.5 Medium / Large モデルを対象とします。 +* **モデル構造:** U-Netの代わりにMMDiT (Transformerベース) を使用します。Text EncoderとしてCLIP-L, CLIP-G, T5-XXLの三つを使用します。VAEはSDXLと互換性がありません。 +* **引数:** SD3/3.5モデル、Text Encoder群、VAEを指定する引数があります。ただし、単一ファイルの`.safetensors`形式であれば、内部で自動的に分離されるため、個別のパス指定は必須ではありません。 +* **一部引数の非互換性:** Stable Diffusion v1/v2向けの引数(例: `--v2`, `--v_parameterization`, `--clip_skip`)はSD3/3.5の学習では使用されません。 +* **SD3特有の引数:** Text Encoderのアテンションマスクやドロップアウト率、Positional Embeddingの調整(SD3.5向け)、タイムステップのサンプリングや損失の重み付けに関する引数が追加されています。 +
+ +## 3. Preparation / 準備 + +The following files are required before starting training: + +1. **Training script:** `sd3_train_network.py` +2. **SD3/3.5 model file:** `.safetensors` file for the base model and paths to each text encoder. Single-file format can also be used. +3. **Dataset definition file (.toml):** Dataset settings in TOML format. (See the [Dataset Configuration Guide](link/to/dataset/config/doc).) In this document we use `my_sd3_dataset_config.toml` as an example. + +
+日本語 +学習を開始する前に、以下のファイルが必要です。 + +1. **学習スクリプト:** `sd3_train_network.py` +2. **SD3/3.5モデルファイル:** 学習のベースとなるSD3/3.5モデルの`.safetensors`ファイル。またText Encoderをそれぞれ対応する引数でパスを指定します。 + * 単一ファイル形式も使用可能です。 +3. **データセット定義ファイル (.toml):** 学習データセットの設定を記述したTOML形式のファイル。(詳細は[データセット設定ガイド](link/to/dataset/config/doc)を参照してください)。 + * 例として`my_sd3_dataset_config.toml`を使用します。 +
+ +## 4. Running the Training / 学習の実行 + +Execute `sd3_train_network.py` from the terminal to start training. The overall command-line format is the same as `train_network.py`, but SD3/3.5 specific options must be supplied. + +Example command: + +```bash +accelerate launch --num_cpu_threads_per_process 1 sd3_train_network.py \ + --pretrained_model_name_or_path="" \ + --clip_l="" \ + --clip_g="" \ + --t5xxl="" \ + --dataset_config="my_sd3_dataset_config.toml" \ + --output_dir="" \ + --output_name="my_sd3_lora" \ + --save_model_as=safetensors \ + --network_module=networks.lora \ + --network_dim=16 \ + --network_alpha=1 \ + --learning_rate=1e-4 \ + --optimizer_type="AdamW8bit" \ + --lr_scheduler="constant" \ + --sdpa \ + --max_train_epochs=10 \ + --save_every_n_epochs=1 \ + --mixed_precision="fp16" \ + --gradient_checkpointing \ + --weighting_scheme="sigma_sqrt" \ + --blocks_to_swap=32 +``` + +*(Write the command on one line or use `\` or `^` for line breaks.)* + +
+日本語 + +学習は、ターミナルから`sd3_train_network.py`を実行することで開始します。基本的なコマンドラインの構造は`train_network.py`と同様ですが、SD3/3.5特有の引数を指定する必要があります。 + +以下に、基本的なコマンドライン実行例を示します。 + +```bash +accelerate launch --num_cpu_threads_per_process 1 sd3_train_network.py + --pretrained_model_name_or_path="" + --clip_l="" + --clip_g="" + --t5xxl="" + --dataset_config="my_sd3_dataset_config.toml" + --output_dir="" + --output_name="my_sd3_lora" + --save_model_as=safetensors + --network_module=networks.lora + --network_dim=16 + --network_alpha=1 + --learning_rate=1e-4 + --optimizer_type="AdamW8bit" + --lr_scheduler="constant" + --sdpa + --max_train_epochs=10 + --save_every_n_epochs=1 + --mixed_precision="fp16" + --gradient_checkpointing + --weighting_scheme="sigma_sqrt" + --blocks_to_swap=32 +``` + +※実際には1行で書くか、適切な改行文字(`\` または `^`)を使用してください。 + +
+ +### 4.1. Explanation of Key Options / 主要なコマンドライン引数の解説 + +Besides the arguments explained in the [train_network.py guide](train_network.md), specify the following SD3/3.5 options. For shared options (`--output_dir`, `--output_name`, etc.), see that guide. + +#### Model Options / モデル関連 + +* `--pretrained_model_name_or_path=""` **required** – Path to the SD3/3.5 model. +* `--clip_l`, `--clip_g`, `--t5xxl`, `--vae` – Skip these if the base model is a single file; otherwise specify each `.safetensors` path. `--vae` is usually unnecessary unless you use a different VAE. + +#### SD3/3.5 Training Parameters / SD3/3.5 学習パラメータ + +* `--t5xxl_max_token_length=` – Max token length for T5-XXL. Default `256`. +* `--apply_lg_attn_mask` – Apply an attention mask to CLIP-L/CLIP-G outputs. +* `--apply_t5_attn_mask` – Apply an attention mask to T5-XXL outputs. +* `--clip_l_dropout_rate`, `--clip_g_dropout_rate`, `--t5_dropout_rate` – Dropout rates for the text encoders. Default `0.0`. +* `--pos_emb_random_crop_rate=` **[SD3.5]** – Probability of randomly cropping the positional embedding. +* `--enable_scaled_pos_embed` **[SD3.5][experimental]** – Scale positional embeddings when training with multiple resolutions. +* `--training_shift=` – Shift applied to the timestep distribution. Default `1.0`. +* `--weighting_scheme=` – Weighting method for loss by timestep. Default `uniform`. +* `--logit_mean=` – Mean value for `logit_normal` weighting scheme. Default `0.0`. +* `--logit_std=` – Standard deviation for `logit_normal` weighting scheme. Default `1.0`. +* `--mode_scale=` – Scale factor for `mode` weighting scheme. Default `1.29`. + +#### Memory and Speed / メモリ・速度関連 + +* `--blocks_to_swap=` **[experimental]** – Swap a number of Transformer blocks between CPU and GPU. More blocks reduce VRAM but slow training. Cannot be used with `--cpu_offload_checkpointing`. +* `--cache_text_encoder_outputs` – Caches the outputs of the text encoders to reduce VRAM usage and speed up training. This is particularly effective for SD3, which uses three text encoders. Recommended when not training the text encoder LoRA. For more details, see the [`sdxl_train_network.py` guide](sdxl_train_network.md). +* `--cache_text_encoder_outputs_to_disk` – Caches the text encoder outputs to disk when the above option is enabled. +* `--t5xxl_device=` **[not supported yet]** – Specifies the device for T5-XXL model. If not specified, uses accelerator's device. +* `--t5xxl_dtype=` **[not supported yet]** – Specifies the dtype for T5-XXL model. If not specified, uses default dtype from mixed precision. +* `--save_clip` **[not supported yet]** – Saves CLIP models to checkpoint (unified checkpoint format not yet supported). +* `--save_t5xxl` **[not supported yet]** – Saves T5-XXL model to checkpoint (unified checkpoint format not yet supported). + +#### Incompatible or Deprecated Options / 非互換・非推奨の引数 + +* `--v2`, `--v_parameterization`, `--clip_skip` – Options for Stable Diffusion v1/v2 that are not used for SD3/3.5. + +
+日本語 + +[`train_network.py`のガイド](train_network.md)で説明されている引数に加え、以下のSD3/3.5特有の引数を指定します。共通の引数については、上記ガイドを参照してください。 + +#### モデル関連 + +* `--pretrained_model_name_or_path=""` **[必須]** + * 学習のベースとなるSD3/3.5モデルの`.safetensors`ファイルのパスを指定します。 +* `--clip_l`, `--clip_g`, `--t5xxl`, `--vae`: + * ベースモデルが単一ファイル形式の場合、これらの指定は不要です(自動的にモデル内部から読み込まれます)。 + * Text Encoderが別ファイルとして提供されている場合は、それぞれの`.safetensors`ファイルのパスを指定します。`--vae` はベースモデルに含まれているため、通常は指定する必要はありません(明示的に異なるVAEを使用する場合のみ指定)。 + +#### SD3/3.5 学習パラメータ + +* `--t5xxl_max_token_length=` – T5-XXLで使用するトークンの最大長を指定します。デフォルトは`256`です。 +* `--apply_lg_attn_mask` – CLIP-L/CLIP-Gの出力にパディング用のマスクを適用します。 +* `--apply_t5_attn_mask` – T5-XXLの出力にパディング用のマスクを適用します。 +* `--clip_l_dropout_rate`, `--clip_g_dropout_rate`, `--t5_dropout_rate` – 各Text Encoderのドロップアウト率を指定します。デフォルトは`0.0`です。 +* `--pos_emb_random_crop_rate=` **[SD3.5向け]** – Positional Embeddingにランダムクロップを適用する確率を指定します。 +* `--enable_scaled_pos_embed` **[SD3.5向け][実験的機能]** – マルチ解像度学習時に解像度に応じてPositional Embeddingをスケーリングします。 +* `--training_shift=` – タイムステップ分布を調整するためのシフト値です。デフォルトは`1.0`です。 +* `--weighting_scheme=` – タイムステップに応じた損失の重み付け方法を指定します。デフォルトは`uniform`です。 +* `--logit_mean=` – `logit_normal`重み付けスキームの平均値です。デフォルトは`0.0`です。 +* `--logit_std=` – `logit_normal`重み付けスキームの標準偏差です。デフォルトは`1.0`です。 +* `--mode_scale=` – `mode`重み付けスキームのスケール係数です。デフォルトは`1.29`です。 + +#### メモリ・速度関連 + +* `--blocks_to_swap=` **[実験的機能]** – TransformerブロックをCPUとGPUでスワップしてVRAMを節約します。`--cpu_offload_checkpointing`とは併用できません。 +* `--cache_text_encoder_outputs` – Text Encoderの出力をキャッシュし、VRAM使用量削減と学習高速化を図ります。SD3は3つのText Encoderを持つため特に効果的です。Text EncoderのLoRAを学習しない場合に推奨されます。詳細は[`sdxl_train_network.py`のガイド](sdxl_train_network.md)を参照してください。 +* `--cache_text_encoder_outputs_to_disk` – 上記オプションと併用し、Text Encoderの出力をディスクにキャッシュします。 +* `--t5xxl_device=` **[未サポート]** – T5-XXLモデルのデバイスを指定します。指定しない場合はacceleratorのデバイスを使用します。 +* `--t5xxl_dtype=` **[未サポート]** – T5-XXLモデルのdtypeを指定します。指定しない場合はデフォルトのdtype(mixed precisionから)を使用します。 +* `--save_clip` **[未サポート]** – CLIPモデルをチェックポイントに保存します(統合チェックポイント形式は未サポート)。 +* `--save_t5xxl` **[未サポート]** – T5-XXLモデルをチェックポイントに保存します(統合チェックポイント形式は未サポート)。 + +#### 非互換・非推奨の引数 + +* `--v2`, `--v_parameterization`, `--clip_skip` – Stable Diffusion v1/v2向けの引数のため、SD3/3.5学習では使用されません。 + +
+ +### 4.2. Starting Training / 学習の開始 + +After setting the required arguments, run the command to begin training. The overall flow and how to check logs are the same as in the [train_network.py guide](train_network.md#32-starting-the-training--学習の開始). + +
+日本語 + +必要な引数を設定したら、コマンドを実行して学習を開始します。全体の流れやログの確認方法は、[train_network.pyのガイド](train_network.md#32-starting-the-training--学習の開始)と同様です。 + +
+ +## 5. LoRA Target Modules / LoRAの学習対象モジュール + +When training LoRA with `sd3_train_network.py`, the following modules are targeted by default: + +* **MMDiT (replaces U-Net)**: + * `qkv` (Query, Key, Value) matrices and `proj_out` (output projection) in the attention blocks. +* **final_layer**: + * The output layer at the end of MMDiT. + +By using `--network_args`, you can apply more detailed controls, such as setting different ranks (dimensions) for each module. + +### Specify rank for each layer in SD3 LoRA / 各層のランクを指定する + +You can specify the rank for each layer in SD3 by specifying the following network_args. If you specify `0`, LoRA will not be applied to that layer. + +When network_args is not specified, the default value (`network_dim`) is applied, same as before. + +|network_args|target layer| +|---|---| +|context_attn_dim|attn in context_block| +|context_mlp_dim|mlp in context_block| +|context_mod_dim|adaLN_modulation in context_block| +|x_attn_dim|attn in x_block| +|x_mlp_dim|mlp in x_block| +|x_mod_dim|adaLN_modulation in x_block| + +`"verbose=True"` is also available for debugging. It shows the rank of each layer. + +example: +``` +--network_args "context_attn_dim=2" "context_mlp_dim=3" "context_mod_dim=4" "x_attn_dim=5" "x_mlp_dim=6" "x_mod_dim=7" "verbose=True" +``` + +You can apply LoRA to the conditioning layers of SD3 by specifying `emb_dims` in network_args. When specifying, be sure to specify 6 numbers in `[]` as a comma-separated list. + +example: +``` +--network_args "emb_dims=[2,3,4,5,6,7]" +``` + +Each number corresponds to `context_embedder`, `t_embedder`, `x_embedder`, `y_embedder`, `final_layer_adaLN_modulation`, `final_layer_linear`. The above example applies LoRA to all conditioning layers, with rank 2 for `context_embedder`, 3 for `t_embedder`, 4 for `context_embedder`, 5 for `y_embedder`, 6 for `final_layer_adaLN_modulation`, and 7 for `final_layer_linear`. + +If you specify `0`, LoRA will not be applied to that layer. For example, `[4,0,0,4,0,0]` applies LoRA only to `context_embedder` and `y_embedder`. + +### Specify blocks to train in SD3 LoRA training + +You can specify the blocks to train in SD3 LoRA training by specifying `train_block_indices` in network_args. The indices are 0-based. The default (when omitted) is to train all blocks. The indices are specified as a list of integers or a range of integers, like `0,1,5,8` or `0,1,4-5,7`. + +The number of blocks depends on the model. The valid range is 0-(the number of blocks - 1). `all` is also available to train all blocks, `none` is also available to train no blocks. + +example: +``` +--network_args "train_block_indices=1,2,6-8" +``` + +
+日本語 + +`sd3_train_network.py`でLoRAを学習させる場合、デフォルトでは以下のモジュールが対象となります。 + +* **MMDiT (U-Netの代替)**: + * Attentionブロック内の`qkv`(Query, Key, Value)行列と、`proj_out`(出力Projection)。 +* **final_layer**: + * MMDiTの最後にある出力層。 + +`--network_args` を使用することで、モジュールごとに異なるランク(次元数)を設定するなど、より詳細な制御が可能です。 + +### SD3 LoRAで各層のランクを指定する + +各層のランクを指定するには、`--network_args`オプションを使用します。`0`を指定すると、その層にはLoRAが適用されません。 + +network_argsが指定されない場合、デフォルト値(`network_dim`)が適用されます。 + +|network_args|target layer| +|---|---| +|context_attn_dim|attn in context_block| +|context_mlp_dim|mlp in context_block| +|context_mod_dim|adaLN_modulation in context_block| +|x_attn_dim|attn in x_block| +|x_mlp_dim|mlp in x_block| +|x_mod_dim|adaLN_modulation in x_block| + +`"verbose=True"`を指定すると、各層のランクが表示されます。 + +例: + +```bash +--network_args "context_attn_dim=2" "context_mlp_dim=3" "context_mod_dim=4" "x_attn_dim=5" "x_mlp_dim=6" "x_mod_dim=7" "verbose=True" +``` + +また、`emb_dims`を指定することで、SD3の条件付け層にLoRAを適用することもできます。指定する際は、必ず`[]`内にカンマ区切りで6つの数字を指定してください。 + +```bash +--network_args "emb_dims=[2,3,4,5,6,7]" +``` + +各数字は、`context_embedder`、`t_embedder`、`x_embedder`、`y_embedder`、`final_layer_adaLN_modulation`、`final_layer_linear`に対応しています。上記の例では、すべての条件付け層にLoRAを適用し、`context_embedder`に2、`t_embedder`に3、`x_embedder`に4、`y_embedder`に5、`final_layer_adaLN_modulation`に6、`final_layer_linear`に7のランクを設定しています。 + +`0`を指定すると、その層にはLoRAが適用されません。例えば、`[4,0,0,4,0,0]`と指定すると、`context_embedder`と`y_embedder`のみにLoRAが適用されます。 + +
+ + +## 6. Using the Trained Model / 学習済みモデルの利用 + +When training finishes, a LoRA model file (e.g. `my_sd3_lora.safetensors`) is saved in the directory specified by `output_dir`. Use this file with inference environments that support SD3/3.5, such as ComfyUI. + +
+日本語 + +学習が完了すると、指定した`output_dir`にLoRAモデルファイル(例: `my_sd3_lora.safetensors`)が保存されます。このファイルは、SD3/3.5モデルに対応した推論環境(例: ComfyUIなど)で使用できます。 + +
+ + +## 7. Others / その他 + +`sd3_train_network.py` shares many features with `train_network.py`, such as sample image generation (`--sample_prompts`, etc.) and detailed optimizer settings. For these, see the [train_network.py guide](train_network.md#5-other-features--その他の機能) or run `python sd3_train_network.py --help`. + +
+日本語 + +`sd3_train_network.py`には、サンプル画像の生成 (`--sample_prompts`など) や詳細なオプティマイザ設定など、`train_network.py`と共通の機能も多く存在します。これらについては、[`train_network.py`のガイド](train_network.md#5-other-features--その他の機能)やスクリプトのヘルプ (`python sd3_train_network.py --help`) を参照してください。 + +
diff --git a/docs/sdxl_train_network.md b/docs/sdxl_train_network.md new file mode 100644 index 000000000..e1f6e9b9b --- /dev/null +++ b/docs/sdxl_train_network.md @@ -0,0 +1,321 @@ +# How to Use the SDXL LoRA Training Script `sdxl_train_network.py` / SDXL LoRA学習スクリプト `sdxl_train_network.py` の使い方 + +This document explains the basic procedure for training a LoRA (Low-Rank Adaptation) model for SDXL (Stable Diffusion XL) using `sdxl_train_network.py` included in the `sd-scripts` repository. + +
+日本語 +このドキュメントでは、`sd-scripts` リポジトリに含まれる `sdxl_train_network.py` を使用して、SDXL (Stable Diffusion XL) モデルに対する LoRA (Low-Rank Adaptation) モデルを学習する基本的な手順について解説します。 +
+ +## 1. Introduction / はじめに + +`sdxl_train_network.py` is a script for training additional networks such as LoRA for SDXL models. The basic usage is common with `train_network.py` (see [How to Use the LoRA Training Script `train_network.py`](train_network.md)), but SDXL model-specific settings are required. + +This guide focuses on SDXL LoRA training, explaining the main differences from `train_network.py` and SDXL-specific configuration items. + +**Prerequisites:** + +* You have cloned the `sd-scripts` repository and set up the Python environment. +* Your training dataset is ready. (Please refer to the [Dataset Preparation Guide](link/to/dataset/doc) for dataset preparation) +* You have read [How to Use the LoRA Training Script `train_network.py`](train_network.md). + +
+日本語 +`sdxl_train_network.py` は、SDXL モデルに対して LoRA などの追加ネットワークを学習させるためのスクリプトです。基本的な使い方は `train_network.py` ([LoRA学習スクリプト `train_network.py` の使い方](train_network.md) 参照) と共通ですが、SDXL モデル特有の設定が必要となります。 + +このガイドでは、SDXL LoRA 学習に焦点を当て、`train_network.py` との主な違いや SDXL 特有の設定項目を中心に説明します。 + +**前提条件:** + +* `sd-scripts` リポジトリのクローンと Python 環境のセットアップが完了していること。 +* 学習用データセットの準備が完了していること。(データセットの準備については[データセット準備ガイド](link/to/dataset/doc)を参照してください) +* [LoRA学習スクリプト `train_network.py` の使い方](train_network.md) を一読していること。 +
+ +## 2. Preparation / 準備 + +Before starting training, you need the following files: + +1. **Training Script:** `sdxl_train_network.py` +2. **Dataset Definition File (.toml):** A TOML format file describing the training dataset configuration. + +### About the Dataset Definition File + +The basic format of the dataset definition file (`.toml`) is the same as for `train_network.py`. Please refer to the [Dataset Configuration Guide](link/to/dataset/config/doc) and [How to Use the LoRA Training Script `train_network.py`](train_network.md#about-the-dataset-definition-file). + +For SDXL, it is common to use high-resolution datasets and the aspect ratio bucketing feature (`enable_bucket = true`). + +In this example, we'll use a file named `my_sdxl_dataset_config.toml`. + +
+日本語 +学習を開始する前に、以下のファイルが必要です。 + +1. **学習スクリプト:** `sdxl_train_network.py` +2. **データセット定義ファイル (.toml):** 学習データセットの設定を記述した TOML 形式のファイル。 + +### データセット定義ファイルについて + +データセット定義ファイル (`.toml`) の基本的な書き方は `train_network.py` と共通です。[データセット設定ガイド](link/to/dataset/config/doc) および [LoRA学習スクリプト `train_network.py` の使い方](train_network.md#データセット定義ファイルについて) を参照してください。 + +SDXL では、高解像度のデータセットや、アスペクト比バケツ機能 (`enable_bucket = true`) の利用が一般的です。 + +ここでは、例として `my_sdxl_dataset_config.toml` という名前のファイルを使用することにします。 +
+ +## 3. Running the Training / 学習の実行 + +Training starts by running `sdxl_train_network.py` from the terminal. + +Here's a basic command line execution example for SDXL LoRA training: + +```bash +accelerate launch --num_cpu_threads_per_process 1 sdxl_train_network.py + --pretrained_model_name_or_path="" + --dataset_config="my_sdxl_dataset_config.toml" + --output_dir="" + --output_name="my_sdxl_lora" + --save_model_as=safetensors + --network_module=networks.lora + --network_dim=32 + --network_alpha=16 + --learning_rate=1e-4 + --unet_lr=1e-4 + --text_encoder_lr1=1e-5 + --text_encoder_lr2=1e-5 + --optimizer_type="AdamW8bit" + --lr_scheduler="constant" + --max_train_epochs=10 + --save_every_n_epochs=1 + --mixed_precision="bf16" + --gradient_checkpointing + --cache_text_encoder_outputs + --cache_latents +``` + +Comparing with the execution example of `train_network.py`, the following points are different: + +* The script to execute is `sdxl_train_network.py`. +* You specify an SDXL base model for `--pretrained_model_name_or_path`. +* `--text_encoder_lr` is split into `--text_encoder_lr1` and `--text_encoder_lr2` (since SDXL has two Text Encoders). +* `--mixed_precision` is recommended to be `bf16` or `fp16`. +* `--cache_text_encoder_outputs` and `--cache_latents` are recommended to reduce VRAM usage. + +Next, we'll explain the main command line arguments that differ from `train_network.py`. For common arguments, please refer to [How to Use the LoRA Training Script `train_network.py`](train_network.md#31-main-command-line-arguments). + +
+日本語 +学習は、ターミナルから `sdxl_train_network.py` を実行することで開始します。 + +以下に、SDXL LoRA 学習における基本的なコマンドライン実行例を示します。 + +```bash +accelerate launch --num_cpu_threads_per_process 1 sdxl_train_network.py + --pretrained_model_name_or_path="" + --dataset_config="my_sdxl_dataset_config.toml" + --output_dir="<学習結果の出力先ディレクトリ>" + --output_name="my_sdxl_lora" + --save_model_as=safetensors + --network_module=networks.lora + --network_dim=32 + --network_alpha=16 + --learning_rate=1e-4 + --unet_lr=1e-4 + --text_encoder_lr1=1e-5 + --text_encoder_lr2=1e-5 + --optimizer_type="AdamW8bit" + --lr_scheduler="constant" + --max_train_epochs=10 + --save_every_n_epochs=1 + --mixed_precision="bf16" + --gradient_checkpointing + --cache_text_encoder_outputs + --cache_latents +``` + +`train_network.py` の実行例と比較すると、以下の点が異なります。 + +* 実行するスクリプトが `sdxl_train_network.py` になります。 +* `--pretrained_model_name_or_path` には SDXL のベースモデルを指定します。 +* `--text_encoder_lr` が `--text_encoder_lr1` と `--text_encoder_lr2` に分かれています(SDXL は2つの Text Encoder を持つため)。 +* `--mixed_precision` は `bf16` または `fp16` が推奨されます。 +* `--cache_text_encoder_outputs` や `--cache_latents` は VRAM 使用量を削減するために推奨されます。 + +次に、`train_network.py` との差分となる主要なコマンドライン引数について解説します。共通の引数については、[LoRA学習スクリプト `train_network.py` の使い方](train_network.md#31-主要なコマンドライン引数) を参照してください。 +
+ +### 3.1. Main Command Line Arguments (Differences) / 主要なコマンドライン引数(差分) + +#### Model Related / モデル関連 + +* `--pretrained_model_name_or_path=""` **[Required]** + * Specifies the **SDXL model** to be used as the base for training. You can specify a Hugging Face Hub model ID (e.g., `"stabilityai/stable-diffusion-xl-base-1.0"`), a local Diffusers format model directory, or a path to a `.safetensors` file. +* `--v2`, `--v_parameterization` + * These arguments are for SD1.x/2.x. When using `sdxl_train_network.py`, since an SDXL model is assumed, these **typically do not need to be specified**. + +#### Dataset Related / データセット関連 + +* `--dataset_config=""` + * This is common with `train_network.py`. + * For SDXL, it is common to use high-resolution data and the bucketing feature (specify `enable_bucket = true` in the `.toml` file). + +#### Output & Save Related / 出力・保存関連 + +* These are common with `train_network.py`. + +#### LoRA Parameters / LoRA パラメータ + +* These are common with `train_network.py`. + +#### Training Parameters / 学習パラメータ + +* `--learning_rate=1e-4` + * Overall learning rate. This becomes the default value if `unet_lr`, `text_encoder_lr1`, and `text_encoder_lr2` are not specified. +* `--unet_lr=1e-4` + * Learning rate for LoRA modules in the U-Net part. If not specified, the value of `--learning_rate` is used. +* `--text_encoder_lr1=1e-5` + * Learning rate for LoRA modules in **Text Encoder 1 (OpenCLIP ViT-G/14)**. If not specified, the value of `--learning_rate` is used. A smaller value than U-Net is recommended. +* `--text_encoder_lr2=1e-5` + * Learning rate for LoRA modules in **Text Encoder 2 (CLIP ViT-L/14)**. If not specified, the value of `--learning_rate` is used. A smaller value than U-Net is recommended. +* `--optimizer_type="AdamW8bit"` + * Common with `train_network.py`. +* `--lr_scheduler="constant"` + * Common with `train_network.py`. +* `--lr_warmup_steps` + * Common with `train_network.py`. +* `--max_train_steps`, `--max_train_epochs` + * Common with `train_network.py`. +* `--mixed_precision="bf16"` + * Mixed precision training setting. For SDXL, `bf16` or `fp16` is recommended. Choose the one supported by your GPU. This reduces VRAM usage and improves training speed. +* `--gradient_accumulation_steps=1` + * Common with `train_network.py`. +* `--gradient_checkpointing` + * Common with `train_network.py`. Recommended to enable for SDXL due to its high memory consumption. +* `--cache_latents` + * Caches VAE outputs in memory (or on disk when `--cache_latents_to_disk` is specified). By skipping VAE computation, this reduces VRAM usage and speeds up training. Image augmentations (`--color_aug`, `--flip_aug`, `--random_crop`, etc.) are disabled. This option is recommended for SDXL training. +* `--cache_latents_to_disk` + * Used with `--cache_latents`, caches to disk. When loading the dataset for the first time, VAE outputs are cached to disk. This is recommended when you have a large number of training images, as it allows you to skip VAE computation on subsequent training runs. +* `--cache_text_encoder_outputs` + * Caches Text Encoder outputs in memory (or on disk when `--cache_text_encoder_outputs_to_disk` is specified). By skipping Text Encoder computation, this reduces VRAM usage and speeds up training. Caption augmentations (`--shuffle_caption`, `--caption_dropout_rate`, etc.) are disabled. + * **Note:** When using this option, LoRA modules for Text Encoder cannot be trained (`--network_train_unet_only` must be specified). +* `--cache_text_encoder_outputs_to_disk` + * Used with `--cache_text_encoder_outputs`, caches to disk. +* `--no_half_vae` + * Runs VAE in `float32` even when using mixed precision (`fp16`/`bf16`). Since SDXL's VAE can be unstable in `float16`, enable this when using `fp16`. +* `--clip_skip` + * Not normally used for SDXL. No need to specify. +* `--fused_backward_pass` + * Fuses gradient computation and optimizer steps to reduce VRAM usage. Available for SDXL. (Currently only supports the `Adafactor` optimizer) + +#### Others / その他 + +* `--seed`, `--logging_dir`, `--log_prefix`, etc. are common with `train_network.py`. + +
+日本語 +#### モデル関連 + +* `--pretrained_model_name_or_path="<モデルのパス>"` **[必須]** + * 学習のベースとなる **SDXL モデル**を指定します。Hugging Face Hub のモデル ID (例: `"stabilityai/stable-diffusion-xl-base-1.0"`) や、ローカルの Diffusers 形式モデルのディレクトリ、`.safetensors` ファイルのパスを指定できます。 +* `--v2`, `--v_parameterization` + * これらの引数は SD1.x/2.x 用です。`sdxl_train_network.py` を使用する場合、SDXL モデルであることが前提となるため、通常は**指定する必要はありません**。 + +#### データセット関連 + +* `--dataset_config="<設定ファイルのパス>"` + * `train_network.py` と共通です。 + * SDXL では高解像度データやバケツ機能 (`.toml` で `enable_bucket = true` を指定) の利用が一般的です。 + +#### 出力・保存関連 + +* `train_network.py` と共通です。 + +#### LoRA パラメータ + +* `train_network.py` と共通です。 + +#### 学習パラメータ + +* `--learning_rate=1e-4` + * 全体の学習率。`unet_lr`, `text_encoder_lr1`, `text_encoder_lr2` が指定されない場合のデフォルト値となります。 +* `--unet_lr=1e-4` + * U-Net 部分の LoRA モジュールに対する学習率。指定しない場合は `--learning_rate` の値が使用されます。 +* `--text_encoder_lr1=1e-5` + * **Text Encoder 1 (OpenCLIP ViT-G/14) の LoRA モジュール**に対する学習率。指定しない場合は `--learning_rate` の値が使用されます。U-Net より小さめの値が推奨されます。 +* `--text_encoder_lr2=1e-5` + * **Text Encoder 2 (CLIP ViT-L/14) の LoRA モジュール**に対する学習率。指定しない場合は `--learning_rate` の値が使用されます。U-Net より小さめの値が推奨されます。 +* `--optimizer_type="AdamW8bit"` + * `train_network.py` と共通です。 +* `--lr_scheduler="constant"` + * `train_network.py` と共通です。 +* `--lr_warmup_steps` + * `train_network.py` と共通です。 +* `--max_train_steps`, `--max_train_epochs` + * `train_network.py` と共通です。 +* `--mixed_precision="bf16"` + * 混合精度学習の設定。SDXL では `bf16` または `fp16` の使用が推奨されます。GPU が対応している方を選択してください。VRAM 使用量を削減し、学習速度を向上させます。 +* `--gradient_accumulation_steps=1` + * `train_network.py` と共通です。 +* `--gradient_checkpointing` + * `train_network.py` と共通です。SDXL はメモリ消費が大きいため、有効にすることが推奨されます。 +* `--cache_latents` + * VAE の出力をメモリ(または `--cache_latents_to_disk` 指定時はディスク)にキャッシュします。VAE の計算を省略できるため、VRAM 使用量を削減し、学習を高速化できます。画像に対する Augmentation (`--color_aug`, `--flip_aug`, `--random_crop` 等) が無効になります。SDXL 学習では推奨されるオプションです。 +* `--cache_latents_to_disk` + * `--cache_latents` と併用し、キャッシュ先をディスクにします。データセットを最初に読み込む際に、VAE の出力をディスクにキャッシュします。二回目以降の学習で VAE の計算を省略できるため、学習データの枚数が多い場合に推奨されます。 +* `--cache_text_encoder_outputs` + * Text Encoder の出力をメモリ(または `--cache_text_encoder_outputs_to_disk` 指定時はディスク)にキャッシュします。Text Encoder の計算を省略できるため、VRAM 使用量を削減し、学習を高速化できます。キャプションに対する Augmentation (`--shuffle_caption`, `--caption_dropout_rate` 等) が無効になります。 + * **注意:** このオプションを使用する場合、Text Encoder の LoRA モジュールは学習できません (`--network_train_unet_only` の指定が必須です)。 +* `--cache_text_encoder_outputs_to_disk` + * `--cache_text_encoder_outputs` と併用し、キャッシュ先をディスクにします。 +* `--no_half_vae` + * 混合精度 (`fp16`/`bf16`) 使用時でも VAE を `float32` で動作させます。SDXL の VAE は `float16` で不安定になることがあるため、`fp16` 指定時には有効にしてください。 +* `--clip_skip` + * SDXL では通常使用しません。指定は不要です。 +* `--fused_backward_pass` + * 勾配計算とオプティマイザのステップを融合し、VRAM使用量を削減します。SDXLで利用可能です。(現在 `Adafactor` オプティマイザのみ対応) + +#### その他 + +* `--seed`, `--logging_dir`, `--log_prefix` などは `train_network.py` と共通です。 +
+ +### 3.2. Starting the Training / 学習の開始 + +After setting the necessary arguments, execute the command to start training. The training progress will be displayed on the console. The basic flow is the same as with `train_network.py`. + +
+日本語 +必要な引数を設定し、コマンドを実行すると学習が開始されます。学習の進行状況はコンソールに出力されます。基本的な流れは `train_network.py` と同じです。 +
+ +## 4. Using the Trained Model / 学習済みモデルの利用 + +When training is complete, a LoRA model file (`.safetensors`, etc.) with the name specified by `output_name` will be saved in the directory specified by `output_dir`. + +This file can be used with GUI tools that support SDXL, such as AUTOMATIC1111/stable-diffusion-webui and ComfyUI. + +
+日本語 +学習が完了すると、`output_dir` で指定したディレクトリに、`output_name` で指定した名前の LoRA モデルファイル (`.safetensors` など) が保存されます。 + +このファイルは、AUTOMATIC1111/stable-diffusion-webui 、ComfyUI などの SDXL に対応した GUI ツールで利用できます。 +
+ +## 5. Supplement: Main Differences from `train_network.py` / 補足: `train_network.py` との主な違い + +* **Target Model:** `sdxl_train_network.py` is exclusively for SDXL models. +* **Text Encoder:** Since SDXL has two Text Encoders, there are differences in learning rate specifications (`--text_encoder_lr1`, `--text_encoder_lr2`), etc. +* **Caching Features:** `--cache_text_encoder_outputs` is particularly effective for SDXL and is recommended. +* **Recommended Settings:** Due to high VRAM usage, mixed precision (`bf16` or `fp16`), `gradient_checkpointing`, and caching features (`--cache_latents`, `--cache_text_encoder_outputs`) are recommended. When using `fp16`, it is recommended to run the VAE in `float32` with `--no_half_vae`. + +For other detailed options, please refer to the script's help (`python sdxl_train_network.py --help`) and other documents in the repository. + +
+日本語 +* **対象モデル:** `sdxl_train_network.py` は SDXL モデル専用です。 +* **Text Encoder:** SDXL は 2 つの Text Encoder を持つため、学習率の指定 (`--text_encoder_lr1`, `--text_encoder_lr2`) などが異なります。 +* **キャッシュ機能:** `--cache_text_encoder_outputs` は SDXL で特に効果が高く、推奨されます。 +* **推奨設定:** VRAM 使用量が大きいため、`bf16` または `fp16` の混合精度、`gradient_checkpointing`、キャッシュ機能 (`--cache_latents`, `--cache_text_encoder_outputs`) の利用が推奨されます。`fp16` 指定時は、VAE は `--no_half_vae` で `float32` 動作を推奨します。 + +その他の詳細なオプションについては、スクリプトのヘルプ (`python sdxl_train_network.py --help`) やリポジトリ内の他のドキュメントを参照してください。 +
\ No newline at end of file diff --git a/docs/train_lllite_README.md b/docs/train_lllite_README.md index a05f87f5f..1bd8e4ae1 100644 --- a/docs/train_lllite_README.md +++ b/docs/train_lllite_README.md @@ -185,7 +185,7 @@ for img_file in img_files: ### Creating a dataset configuration file -You can use the command line arguments of `sdxl_train_control_net_lllite.py` to specify the conditioning image directory. However, if you want to use a `.toml` file, specify the conditioning image directory in `conditioning_data_dir`. +You can use the command line argument `--conditioning_data_dir` of `sdxl_train_control_net_lllite.py` to specify the conditioning image directory. However, if you want to use a `.toml` file, specify the conditioning image directory in `conditioning_data_dir`. ```toml [general] diff --git a/docs/train_network.md b/docs/train_network.md new file mode 100644 index 000000000..06c08a424 --- /dev/null +++ b/docs/train_network.md @@ -0,0 +1,314 @@ +# How to use the LoRA training script `train_network.py` / LoRA学習スクリプト `train_network.py` の使い方 + +This document explains the basic procedures for training LoRA (Low-Rank Adaptation) models using `train_network.py` included in the `sd-scripts` repository. + +
+日本語 +このドキュメントでは、`sd-scripts` リポジトリに含まれる `train_network.py` を使用して LoRA (Low-Rank Adaptation) モデルを学習する基本的な手順について解説します。 +
+ +## 1. Introduction / はじめに + +`train_network.py` is a script for training additional networks such as LoRA on Stable Diffusion models (v1.x, v2.x). It allows for additional training on the original model with a low computational cost, enabling the creation of models that reproduce specific characters or art styles. + +This guide focuses on LoRA training and explains the basic configuration items. + +**Prerequisites:** + +* The `sd-scripts` repository has been cloned and the Python environment has been set up. +* The training dataset has been prepared. (For dataset preparation, please refer to [this guide](link/to/dataset/doc)) + +
+日本語 + +`train_network.py` は、Stable Diffusion モデル(v1.x, v2.x)に対して、LoRA などの追加ネットワークを学習させるためのスクリプトです。少ない計算コストで元のモデルに追加学習を行い、特定のキャラクターや画風を再現するモデルを作成できます。 + +このガイドでは、LoRA 学習に焦点を当て、基本的な設定項目を中心に説明します。 + +**前提条件:** + +* `sd-scripts` リポジトリのクローンと Python 環境のセットアップが完了していること。 +* 学習用データセットの準備が完了していること。(データセットの準備については[こちら](link/to/dataset/doc)を参照してください) +
+ +## 2. Preparation / 準備 + +Before starting training, you will need the following files: + +1. **Training script:** `train_network.py` +2. **Dataset definition file (.toml):** A file in TOML format that describes the configuration of the training dataset. + +### About the Dataset Definition File / データセット定義ファイルについて + +The dataset definition file (`.toml`) contains detailed settings such as the directory of images to use, repetition count, caption settings, resolution buckets (optional), etc. + +For more details on how to write the dataset definition file, please refer to the [Dataset Configuration Guide](link/to/dataset/config/doc). + +In this guide, we will use a file named `my_dataset_config.toml` as an example. + +
+日本語 + +学習を開始する前に、以下のファイルが必要です。 + +1. **学習スクリプト:** `train_network.py` +2. **データセット定義ファイル (.toml):** 学習データセットの設定を記述した TOML 形式のファイル。 + +**データセット定義ファイルについて** + +データセット定義ファイル (`.toml`) には、使用する画像のディレクトリ、繰り返し回数、キャプションの設定、解像度バケツ(任意)などの詳細な設定を記述します。 + +データセット定義ファイルの詳しい書き方については、[データセット設定ガイド](link/to/dataset/config/doc)を参照してください。 + +ここでは、例として `my_dataset_config.toml` という名前のファイルを使用することにします。 +
+ +## 3. Running the Training / 学習の実行 + +Training is started by executing `train_network.py` from the terminal. When executing, various training settings are specified as command-line arguments. + +Below is a basic command-line execution example: + +```bash +accelerate launch --num_cpu_threads_per_process 1 train_network.py + --pretrained_model_name_or_path="" + --dataset_config="my_dataset_config.toml" + --output_dir="" + --output_name="my_lora" + --save_model_as=safetensors + --network_module=networks.lora + --network_dim=16 + --network_alpha=1 + --learning_rate=1e-4 + --optimizer_type="AdamW8bit" + --lr_scheduler="constant" + --sdpa + --max_train_epochs=10 + --save_every_n_epochs=1 + --mixed_precision="fp16" + --gradient_checkpointing +``` + +In reality, you need to write this in a single line, but it's shown with line breaks for readability (on Linux or Mac, you can add `\` at the end of each line to break lines). For Windows, either write it in a single line without breaks or add `^` at the end of each line. + +Next, we'll explain the main command-line arguments. + +
+日本語 + +学習は、ターミナルから `train_network.py` を実行することで開始します。実行時には、学習に関する様々な設定をコマンドライン引数として指定します。 + +以下に、基本的なコマンドライン実行例を示します。 + +実際には1行で書く必要がありますが、見やすさのために改行しています(Linux や Mac では `\` を行末に追加することで改行できます)。Windows の場合は、改行せずに1行で書くか、`^` を行末に追加してください。 + +次に、主要なコマンドライン引数について解説します。 +
+ +### 3.1. Main Command-Line Arguments / 主要なコマンドライン引数 + +#### Model Related / モデル関連 + +* `--pretrained_model_name_or_path=""` **[Required]** + * Specifies the Stable Diffusion model to be used as the base for training. You can specify the path to a local `.ckpt` or `.safetensors` file, or a directory containing a Diffusers format model. You can also specify a Hugging Face Hub model ID (e.g., `"stabilityai/stable-diffusion-2-1-base"`). +* `--v2` + * Specify this when the base model is Stable Diffusion v2.x. +* `--v_parameterization` + * Specify this when training with a v-prediction model (such as v2.x 768px models). + +#### Dataset Related / データセット関連 + +* `--dataset_config=""` + * Specifies the path to a `.toml` file describing the dataset configuration. (For details on dataset configuration, see [here](link/to/dataset/config/doc)) + * It's also possible to specify dataset settings from the command line, but using a `.toml` file is recommended as it becomes lengthy. + +#### Output and Save Related / 出力・保存関連 + +* `--output_dir=""` **[Required]** + * Specifies the directory where trained LoRA models, sample images, logs, etc. will be output. +* `--output_name=""` **[Required]** + * Specifies the filename of the trained LoRA model (excluding the extension). +* `--save_model_as="safetensors"` + * Specifies the format for saving the model. You can choose from `safetensors` (recommended), `ckpt`, or `pt`. The default is `safetensors`. +* `--save_every_n_epochs=1` + * Saves the model every specified number of epochs. If not specified, only the final model will be saved. +* `--save_every_n_steps=1000` + * Saves the model every specified number of steps. If both epoch and step saving are specified, both will be saved. + +#### LoRA Parameters / LoRA パラメータ + +* `--network_module=networks.lora` **[Required]** + * Specifies the type of network to train. For LoRA, specify `networks.lora`. +* `--network_dim=16` **[Required]** + * Specifies the rank (dimension) of LoRA. Higher values increase expressiveness but also increase file size and computational cost. Values between 4 and 128 are commonly used. There is no default (module dependent). +* `--network_alpha=1` + * Specifies the alpha value for LoRA. This parameter is related to learning rate scaling. It is generally recommended to set it to about half the value of `network_dim`, but it can also be the same value as `network_dim`. The default is 1. Setting it to the same value as `network_dim` will result in behavior similar to older versions. + +#### Training Parameters / 学習パラメータ + +* `--learning_rate=1e-4` + * Specifies the learning rate. For LoRA training (when alpha value is 1), relatively higher values (e.g., from `1e-4` to `1e-3`) are often used. +* `--unet_lr=1e-4` + * Used to specify a separate learning rate for the LoRA modules in the U-Net part. If not specified, the value of `--learning_rate` is used. +* `--text_encoder_lr=1e-5` + * Used to specify a separate learning rate for the LoRA modules in the Text Encoder part. If not specified, the value of `--learning_rate` is used. A smaller value than that for U-Net is recommended. +* `--optimizer_type="AdamW8bit"` + * Specifies the optimizer to use for training. Options include `AdamW8bit` (requires `bitsandbytes`), `AdamW`, `Lion` (requires `lion-pytorch`), `DAdaptation` (requires `dadaptation`), and `Adafactor`. `AdamW8bit` is memory-efficient and widely used. +* `--lr_scheduler="constant"` + * Specifies the learning rate scheduler. This is the method for changing the learning rate as training progresses. Options include `constant` (no change), `cosine` (cosine curve), `linear` (linear decay), `constant_with_warmup` (constant with warmup), and `cosine_with_restarts`. `constant`, `cosine`, and `constant_with_warmup` are commonly used. +* `--lr_warmup_steps=500` + * Specifies the number of warmup steps for the learning rate scheduler. This is the period during which the learning rate gradually increases at the start of training. Valid when the `lr_scheduler` supports warmup. +* `--max_train_steps=10000` + * Specifies the total number of training steps. If `max_train_epochs` is specified, that takes precedence. +* `--max_train_epochs=12` + * Specifies the number of training epochs. If this is specified, `max_train_steps` is ignored. +* `--sdpa` + * Uses Scaled Dot-Product Attention. This can reduce memory usage and improve training speed for LoRA training. +* `--mixed_precision="fp16"` + * Specifies the mixed precision training setting. Options are `no` (disabled), `fp16` (half precision), and `bf16` (bfloat16). If your GPU supports it, specifying `fp16` or `bf16` can improve training speed and reduce memory usage. +* `--gradient_accumulation_steps=1` + * Specifies the number of steps to accumulate gradients. This effectively increases the batch size to `train_batch_size * gradient_accumulation_steps`. Set a larger value if GPU memory is insufficient. Usually `1` is fine. + +#### Others / その他 + +* `--seed=42` + * Specifies the random seed. Set this if you want to ensure reproducibility of the training. +* `--logging_dir=""` + * Specifies the directory to output logs for TensorBoard, etc. If not specified, logs will not be output. +* `--log_prefix=""` + * Specifies the prefix for the subdirectory name created within `logging_dir`. +* `--gradient_checkpointing` + * Enables Gradient Checkpointing. This can significantly reduce memory usage but slightly decreases training speed. Useful when memory is limited. +* `--clip_skip=1` + * Specifies how many layers to skip from the last layer of the Text Encoder. Specifying `2` will use the output from the second-to-last layer. `None` or `1` means no skip (uses the last layer). Check the recommended value for the model you are training. + +
+日本語 + +#### モデル関連 + +* `--pretrained_model_name_or_path="<モデルのパス>"` **[必須]** + * 学習のベースとなる Stable Diffusion モデルを指定します。ローカルの `.ckpt` または `.safetensors` ファイルのパス、あるいは Diffusers 形式モデルのディレクトリを指定できます。Hugging Face Hub のモデル ID (例: `"stabilityai/stable-diffusion-2-1-base"`) も指定可能です。 +* `--v2` + * ベースモデルが Stable Diffusion v2.x の場合に指定します。 +* `--v_parameterization` + * v-prediction モデル(v2.x の 768px モデルなど)で学習する場合に指定します。 + +#### データセット関連 + +* `--dataset_config="<設定ファイルのパス>"` + * データセット設定を記述した `.toml` ファイルのパスを指定します。(データセット設定の詳細は[こちら](link/to/dataset/config/doc)) + * コマンドラインからデータセット設定を指定することも可能ですが、長くなるため `.toml` ファイルを使用することを推奨します。 + +#### 出力・保存関連 + +* `--output_dir="<出力先ディレクトリ>"` **[必須]** + * 学習済み LoRA モデルやサンプル画像、ログなどが出力されるディレクトリを指定します。 +* `--output_name="<出力ファイル名>"` **[必須]** + * 学習済み LoRA モデルのファイル名(拡張子を除く)を指定します。 +* `--save_model_as="safetensors"` + * モデルの保存形式を指定します。`safetensors` (推奨), `ckpt`, `pt` から選択できます。デフォルトは `safetensors` です。 +* `--save_every_n_epochs=1` + * 指定したエポックごとにモデルを保存します。省略するとエポックごとの保存は行われません(最終モデルのみ保存)。 +* `--save_every_n_steps=1000` + * 指定したステップごとにモデルを保存します。エポック指定 (`save_every_n_epochs`) と同時に指定された場合、両方とも保存されます。 + +#### LoRA パラメータ + +* `--network_module=networks.lora` **[必須]** + * 学習するネットワークの種別を指定します。LoRA の場合は `networks.lora` を指定します。 +* `--network_dim=16` **[必須]** + * LoRA のランク (rank / 次元数) を指定します。値が大きいほど表現力は増しますが、ファイルサイズと計算コストが増加します。一般的には 4〜128 程度の値が使われます。デフォルトは指定されていません(モジュール依存)。 +* `--network_alpha=1` + * LoRA のアルファ値 (alpha) を指定します。学習率のスケーリングに関係するパラメータで、一般的には `network_dim` の半分程度の値を指定することが推奨されますが、`network_dim` と同じ値を指定する場合もあります。デフォルトは 1 です。`network_dim` と同じ値に設定すると、旧バージョンと同様の挙動になります。 + +#### 学習パラメータ + +* `--learning_rate=1e-4` + * 学習率を指定します。LoRA 学習では(アルファ値が1の場合)比較的高めの値(例: `1e-4`から`1e-3`)が使われることが多いです。 +* `--unet_lr=1e-4` + * U-Net 部分の LoRA モジュールに対する学習率を個別に指定する場合に使用します。指定しない場合は `--learning_rate` の値が使用されます。 +* `--text_encoder_lr=1e-5` + * Text Encoder 部分の LoRA モジュールに対する学習率を個別に指定する場合に使用します。指定しない場合は `--learning_rate` の値が使用されます。U-Net よりも小さめの値が推奨されます。 +* `--optimizer_type="AdamW8bit"` + * 学習に使用するオプティマイザを指定します。`AdamW8bit` (要 `bitsandbytes`), `AdamW`, `Lion` (要 `lion-pytorch`), `DAdaptation` (要 `dadaptation`), `Adafactor` などが選択可能です。`AdamW8bit` はメモリ効率が良く、広く使われています。 +* `--lr_scheduler="constant"` + * 学習率スケジューラを指定します。学習の進行に合わせて学習率を変化させる方法です。`constant` (変化なし), `cosine` (コサインカーブ), `linear` (線形減衰), `constant_with_warmup` (ウォームアップ付き定数), `cosine_with_restarts` などが選択可能です。`constant`や`cosine` 、 `constant_with_warmup` がよく使われます。 +* `--lr_warmup_steps=500` + * 学習率スケジューラのウォームアップステップ数を指定します。学習開始時に学習率を徐々に上げていく期間です。`lr_scheduler` がウォームアップをサポートする場合に有効です。 +* `--max_train_steps=10000` + * 学習の総ステップ数を指定します。`max_train_epochs` が指定されている場合はそちらが優先されます。 +* `--max_train_epochs=12` + * 学習のエポック数を指定します。これを指定すると `max_train_steps` は無視されます。 +* `--sdpa` + * Scaled Dot-Product Attention を使用します。LoRA の学習において、メモリ使用量を削減し、学習速度を向上させることができます。 +* `--mixed_precision="fp16"` + * 混合精度学習の設定を指定します。`no` (無効), `fp16` (半精度), `bf16` (bfloat16) から選択できます。GPU が対応している場合は `fp16` または `bf16` を指定することで、学習速度の向上とメモリ使用量の削減が期待できます。 +* `--gradient_accumulation_steps=1` + * 勾配を累積するステップ数を指定します。実質的なバッチサイズを `train_batch_size * gradient_accumulation_steps` に増やす効果があります。GPU メモリが足りない場合に大きな値を設定します。通常は `1` で問題ありません。 + +#### その他 + +* `--seed=42` + * 乱数シードを指定します。学習の再現性を確保したい場合に設定します。 +* `--logging_dir="<ログディレクトリ>"` + * TensorBoard などのログを出力するディレクトリを指定します。指定しない場合、ログは出力されません。 +* `--log_prefix="<プレフィックス>"` + * `logging_dir` 内に作成されるサブディレクトリ名の接頭辞を指定します。 +* `--gradient_checkpointing` + * Gradient Checkpointing を有効にします。メモリ使用量を大幅に削減できますが、学習速度は若干低下します。メモリが厳しい場合に有効です。 +* `--clip_skip=1` + * Text Encoder の最後の層から数えて何層スキップするかを指定します。`2` を指定すると最後から 2 層目の出力を使用します。`None` または `1` はスキップなし(最後の層を使用)を意味します。学習対象のモデルの推奨する値を確認してください。 +
+ +### 3.2. Starting the Training / 学習の開始 + +After setting the necessary arguments and executing the command, training will begin. The progress of the training will be output to the console. If `logging_dir` is specified, you can visually check the training status (loss, learning rate, etc.) with TensorBoard. + +```bash +tensorboard --logdir +``` + +
+日本語 + +必要な引数を設定し、コマンドを実行すると学習が開始されます。学習の進行状況はコンソールに出力されます。`logging_dir` を指定した場合は、TensorBoard などで学習状況(損失や学習率など)を視覚的に確認できます。 +
+ +## 4. Using the Trained Model / 学習済みモデルの利用 + +Once training is complete, a LoRA model file (`.safetensors` or `.ckpt`) with the name specified by `output_name` will be saved in the directory specified by `output_dir`. + +This file can be used with GUI tools such as AUTOMATIC1111/stable-diffusion-webui, ComfyUI, etc. + +
+日本語 + +学習が完了すると、`output_dir` で指定したディレクトリに、`output_name` で指定した名前の LoRA モデルファイル (`.safetensors` または `.ckpt`) が保存されます。 + +このファイルは、AUTOMATIC1111/stable-diffusion-webui 、ComfyUI などの GUI ツールで利用できます。 +
+ +## 5. Other Features / その他の機能 + +`train_network.py` has many other options not introduced here. + +* Sample image generation (`--sample_prompts`, `--sample_every_n_steps`, etc.) +* More detailed optimizer settings (`--optimizer_args`, etc.) +* Caption preprocessing (`--shuffle_caption`, `--keep_tokens`, etc.) +* Additional network settings (`--network_args`, etc.) + +For these features, please refer to the script's help (`python train_network.py --help`) or other documents in the repository. + +
+日本語 + +`train_network.py` には、ここで紹介した以外にも多くのオプションがあります。 + +* サンプル画像の生成 (`--sample_prompts`, `--sample_every_n_steps` など) +* より詳細なオプティマイザ設定 (`--optimizer_args` など) +* キャプションの前処理 (`--shuffle_caption`, `--keep_tokens` など) +* ネットワークの追加設定 (`--network_args` など) + +これらの機能については、スクリプトのヘルプ (`python train_network.py --help`) やリポジトリ内の他のドキュメントを参照してください。 +
\ No newline at end of file diff --git a/docs/train_network_advanced.md b/docs/train_network_advanced.md new file mode 100644 index 000000000..c1fd86a22 --- /dev/null +++ b/docs/train_network_advanced.md @@ -0,0 +1,515 @@ +# Advanced Settings: Detailed Guide for SDXL LoRA Training Script `sdxl_train_network.py` / 高度な設定: SDXL LoRA学習スクリプト `sdxl_train_network.py` 詳細ガイド + +This document describes the advanced options available when training LoRA models for SDXL (Stable Diffusion XL) with `sdxl_train_network.py` in the `sd-scripts` repository. For the basics, please read [How to Use the LoRA Training Script `train_network.py`](train_network.md) and [How to Use the SDXL LoRA Training Script `sdxl_train_network.py`](sdxl_train_network.md). + +This guide targets experienced users who want to fine tune settings in detail. + +**Prerequisites:** + +* You have cloned the `sd-scripts` repository and prepared a Python environment. +* A training dataset and its `.toml` configuration are ready (see the dataset configuration guide). +* You are familiar with running basic LoRA training commands. + +## 1. Command Line Options / コマンドライン引数 詳細解説 + +`sdxl_train_network.py` inherits the functionality of `train_network.py` and adds SDXL-specific features. Major options are grouped and explained below. For common arguments, see the other guides mentioned above. + +### 1.1. Model Loading + +* `--pretrained_model_name_or_path=\"\"` **[Required]**: specify the base SDXL model. Supports a Hugging Face model ID, a local Diffusers directory or a `.safetensors` file. +* `--vae=\"\"`: optionally use a different VAE. Specify when using a VAE other than the one included in the SDXL model. Can specify `.ckpt` or `.safetensors` files. +* `--no_half_vae`: keep the VAE in float32 even with fp16/bf16 training. The VAE for SDXL can become unstable with `float16`, so it is recommended to enable this when `fp16` is specified. Usually unnecessary for `bf16`. +* `--fp8_base` / `--fp8_base_unet`: **Experimental**: load the base model (U-Net, Text Encoder) or just the U-Net in FP8 to reduce VRAM (requires PyTorch 2.1+). For details, refer to the relevant section in TODO add document later (this is an SD3 explanation but also applies to SDXL). + +### 1.2. Dataset Settings + +* `--dataset_config=\"\"`: specify a `.toml` dataset config. High resolution data and aspect ratio buckets (specify `enable_bucket = true` in `.toml`) are common for SDXL. The resolution steps for aspect ratio buckets (`bucket_reso_steps`) must be multiples of 32 for SDXL. For details on writing `.toml` files, refer to the [Dataset Configuration Guide](link/to/dataset/config/doc). + +### 1.3. Output and Saving + +Options match `train_network.py`: + +* `--output_dir`, `--output_name` (both required) +* `--save_model_as` (recommended `safetensors`), `ckpt`, `pt`, `diffusers`, `diffusers_safetensors` +* `--save_precision=\"fp16\"`, `\"bf16\"`, `\"float\"`: Specifies the precision for saving the model. If not specified, the model is saved with the training precision (`fp16`, `bf16`, etc.). +* `--save_every_n_epochs=N`, `--save_every_n_steps=N`: Saves the model every N epochs/steps. +* `--save_last_n_epochs=M`, `--save_last_n_steps=M`: When saving at every epoch/step, only the latest M files are kept, and older ones are deleted. +* `--save_state`, `--save_state_on_train_end`: Saves the training state (`state`), including Optimizer status, etc., when saving the model or at the end of training. Required for resuming training with the `--resume` option. +* `--save_last_n_epochs_state=M`, `--save_last_n_steps_state=M`: Limits the number of saved `state` files to M. Overrides the `--save_last_n_epochs/steps` specification. +* `--no_metadata`: Does not save metadata to the output model. +* `--save_state_to_huggingface` and related options (e.g., `--huggingface_repo_id`): Options related to uploading models and states to Hugging Face Hub. See TODO add document for details. + +### 1.4. Network Parameters (LoRA) + +* `--network_module=networks.lora` **[Required]** +* `--network_dim=N` **[Required]**: Specifies the rank (dimensionality) of LoRA. For SDXL, values like 32 or 64 are often tried, but adjustment is necessary depending on the dataset and purpose. +* `--network_alpha=M`: LoRA alpha value. Generally around half of `network_dim` or the same value as `network_dim`. Default is 1. +* `--network_dropout=P`: Dropout rate (0.0-1.0) within LoRA modules. Can be effective in suppressing overfitting. Default is None (no dropout). +* `--network_args ...`: Allows advanced settings by specifying additional arguments to the network module in `key=value` format. For LoRA, the following advanced settings are available: + * **Block-wise dimensions/alphas:** + * Allows specifying different `dim` and `alpha` for each block of the U-Net. This enables adjustments to strengthen or weaken the influence of specific layers. + * `block_dims`: Comma-separated dims for Linear and Conv2d 1x1 layers in U-Net (23 values for SDXL). + * `block_alphas`: Comma-separated alpha values corresponding to the above. + * `conv_block_dims`: Comma-separated dims for Conv2d 3x3 layers in U-Net. + * `conv_block_alphas`: Comma-separated alpha values corresponding to the above. + * Blocks not specified will use values from `--network_dim`/`--network_alpha` or `--conv_dim`/`--conv_alpha` (if they exist). + * For details, refer to [Block-wise learning rate for LoRA](train_network.md#lora-の階層別学習率) (in train_network.md, applicable to SDXL) and the implementation ([lora.py](lora.py)). + * **LoRA+:** + * `loraplus_lr_ratio=R`: Sets the learning rate of LoRA's upward weights (UP) to R times the learning rate of downward weights (DOWN). Expected to improve learning speed. Paper recommends 16. + * `loraplus_unet_lr_ratio=RU`: Specifies the LoRA+ learning rate ratio for the U-Net part individually. + * `loraplus_text_encoder_lr_ratio=RT`: Specifies the LoRA+ learning rate ratio for the Text Encoder part individually (multiplied by the learning rates specified with `--text_encoder_lr1`, `--text_encoder_lr2`). + * For details, refer to [README](../README.md#jan-17-2025--2025-01-17-version-090) and the implementation ([lora.py](lora.py)). +* `--network_train_unet_only`: Trains only the LoRA modules of the U-Net. Specify this if not training Text Encoders. Required when using `--cache_text_encoder_outputs`. +* `--network_train_text_encoder_only`: Trains only the LoRA modules of the Text Encoders. Specify this if not training the U-Net. +* `--network_weights=\"\"`: Starts training by loading pre-trained LoRA weights. Used for fine-tuning or resuming training. The difference from `--resume` is that this option only loads LoRA module weights, while `--resume` also restores Optimizer state, step count, etc. +* `--dim_from_weights`: Automatically reads the LoRA dimension (`dim`) from the weight file specified by `--network_weights`. Specification of `--network_dim` becomes unnecessary. + +### 1.5. Training Parameters + +* `--learning_rate=LR`: Sets the overall learning rate. This becomes the default value for each module (`unet_lr`, `text_encoder_lr1`, `text_encoder_lr2`). Values like `1e-3` or `1e-4` are often tried. +* `--unet_lr=LR_U`: Learning rate for the LoRA module of the U-Net part. +* `--text_encoder_lr1=LR_TE1`: Learning rate for the LoRA module of Text Encoder 1 (OpenCLIP ViT-G/14). Usually, a smaller value than U-Net (e.g., `1e-5`, `2e-5`) is recommended. +* `--text_encoder_lr2=LR_TE2`: Learning rate for the LoRA module of Text Encoder 2 (CLIP ViT-L/14). Usually, a smaller value than U-Net (e.g., `1e-5`, `2e-5`) is recommended. +* `--optimizer_type=\"...\"`: Specifies the optimizer to use. Options include `AdamW8bit` (memory-efficient, common), `Adafactor` (even more memory-efficient, proven in SDXL full model training), `Lion`, `DAdaptation`, `Prodigy`, etc. Each optimizer may require additional arguments (see `--optimizer_args`). `AdamW8bit` or `PagedAdamW8bit` (requires `bitsandbytes`) are common. `Adafactor` is memory-efficient but slightly complex to configure (relative step (`relative_step=True`) recommended, `adafactor` learning rate scheduler recommended). `DAdaptation`, `Prodigy` have automatic learning rate adjustment but cannot be used with LoRA+. Specify a learning rate around `1.0`. For details, see the `get_optimizer` function in [train_util.py](train_util.py). +* `--optimizer_args ...`: Specifies additional arguments to the optimizer in `key=value` format (e.g., `\"weight_decay=0.01\"` `\"betas=0.9,0.999\"`). +* `--lr_scheduler=\"...\"`: Specifies the learning rate scheduler. Options include `constant` (no change), `cosine` (cosine curve), `linear` (linear decay), `constant_with_warmup` (constant with warmup), `cosine_with_restarts`, etc. `constant`, `cosine`, and `constant_with_warmup` are commonly used. Some schedulers require additional arguments (see `--lr_scheduler_args`). If using optimizers with auto LR adjustment like `DAdaptation` or `Prodigy`, a scheduler is not needed (`constant` should be specified). +* `--lr_warmup_steps=N`: Number of warmup steps for the learning rate scheduler. The learning rate gradually increases during this period at the start of training. If N < 1, it's interpreted as a fraction of total steps. +* `--lr_scheduler_num_cycles=N` / `--lr_scheduler_power=P`: Parameters for specific schedulers (`cosine_with_restarts`, `polynomial`). +* `--max_train_steps=N` / `--max_train_epochs=N`: Specifies the total number of training steps or epochs. Epoch specification takes precedence. +* `--mixed_precision=\"bf16\"` / `\"fp16\"` / `\"no\"`: Mixed precision training settings. For SDXL, using `bf16` (if GPU supports it) or `fp16` is strongly recommended. Reduces VRAM usage and improves training speed. +* `--full_fp16` / `--full_bf16`: Performs gradient calculations entirely in half-precision/bf16. Can further reduce VRAM usage but may affect training stability. Use if VRAM is critically low. +* `--gradient_accumulation_steps=N`: Accumulates gradients for N steps before updating the optimizer. Effectively increases the batch size to `train_batch_size * N`, achieving the effect of a larger batch size with less VRAM. Default is 1. +* `--max_grad_norm=N`: Gradient clipping threshold. Clips gradients if their norm exceeds N. Default is 1.0. `0` disables it. +* `--gradient_checkpointing`: Significantly reduces memory usage but slightly decreases training speed. Recommended for SDXL due to high memory consumption. +* `--fused_backward_pass`: **Experimental**: Fuses gradient calculation and optimizer steps to reduce VRAM usage. Available for SDXL. Currently only supports `Adafactor` optimizer. Cannot be used with Gradient Accumulation. +* `--resume=\"\"`: Resumes training from a saved state (saved with `--save_state`). Restores optimizer state, step count, etc. + +### 1.6. Caching + +Caching is effective for SDXL due to its high computational cost. + +* `--cache_latents`: Caches VAE outputs (latents) in memory. Skips VAE computation, reducing VRAM usage and speeding up training. **Note:** Image augmentations (`color_aug`, `flip_aug`, `random_crop`, etc.) will be disabled. +* `--cache_latents_to_disk`: Used with `--cache_latents` to cache to disk. Particularly effective for large datasets or multiple training runs. Caches are generated on disk during the first run and loaded from there on subsequent runs. +* `--cache_text_encoder_outputs`: Caches Text Encoder outputs in memory. Skips Text Encoder computation, reducing VRAM usage and speeding up training. **Note:** Caption augmentations (`shuffle_caption`, `caption_dropout_rate`, etc.) will be disabled. **Also, when using this option, Text Encoder LoRA modules cannot be trained (requires `--network_train_unet_only`).** +* `--cache_text_encoder_outputs_to_disk`: Used with `--cache_text_encoder_outputs` to cache to disk. +* `--skip_cache_check`: Skips validation of cache file contents. File existence is checked, and if not found, caches are generated. Usually not needed unless intentionally re-caching for debugging, etc. + +### 1.7. Sample Image Generation + +Basic options are common with `train_network.py`. + +* `--sample_every_n_steps=N` / `--sample_every_n_epochs=N`: Generates sample images every N steps/epochs. +* `--sample_at_first`: Generates sample images before training starts. +* `--sample_prompts=\"\"`: Specifies a file (`.txt`, `.toml`, `.json`) containing prompts for sample image generation. Format follows [gen_img_diffusers.py](gen_img_diffusers.py). See [documentation](gen_img_README-ja.md) for details. +* `--sample_sampler=\"...\"`: Specifies the sampler (scheduler) for sample image generation. `euler_a`, `dpm++_2m_karras`, etc., are common. See `--help` for choices. + +### 1.8. Logging & Tracking + +* `--logging_dir=\"\"`: Specifies the directory for TensorBoard and other logs. If not specified, logs are not output. +* `--log_with=\"tensorboard\"` / `\"wandb\"` / `\"all\"`: Specifies the logging tool to use. If using `wandb`, `pip install wandb` is required. +* `--log_prefix=\"\"`: Specifies the prefix for subdirectory names created within `logging_dir`. +* `--wandb_api_key=\"\"` / `--wandb_run_name=\"\"`: Options for Weights & Biases (wandb). +* `--log_tracker_name` / `--log_tracker_config`: Advanced tracker configuration options. Usually not needed. +* `--log_config`: Logs the training configuration used (excluding some sensitive information) at the start of training. Helps ensure reproducibility. + +### 1.9. Regularization and Advanced Techniques + +* `--noise_offset=N`: Enables noise offset and specifies its value. Expected to improve bias in image brightness and contrast. Recommended to enable as SDXL base models are trained with this (e.g., 0.0357). Original technical explanation [here](https://www.crosslabs.org/blog/diffusion-with-offset-noise). +* `--noise_offset_random_strength`: Randomly varies noise offset strength between 0 and the specified value. +* `--adaptive_noise_scale=N`: Adjusts noise offset based on the mean absolute value of latents. Used with `--noise_offset`. +* `--multires_noise_iterations=N` / `--multires_noise_discount=D`: Enables multi-resolution noise. Adding noise of different frequency components is expected to improve detail reproduction. Specify iteration count N (around 6-10) and discount rate D (around 0.3). Technical explanation [here](https://wandb.ai/johnowhitaker/multires_noise/reports/Multi-Resolution-Noise-for-Diffusion-Model-Training--VmlldzozNjYyOTU2). +* `--ip_noise_gamma=G` / `--ip_noise_gamma_random_strength`: Enables Input Perturbation Noise. Adds small noise to input (latents) for regularization. Specify Gamma value (around 0.1). Strength can be randomized with `random_strength`. +* `--min_snr_gamma=N`: Applies Min-SNR Weighting Strategy. Adjusts loss weights for timesteps with high noise in early training to stabilize learning. `N=5` etc. are used. +* `--scale_v_pred_loss_like_noise_pred`: In v-prediction models, scales v-prediction loss similarly to noise prediction loss. **Not typically used for SDXL** as it's not a v-prediction model. +* `--v_pred_like_loss=N`: Adds v-prediction-like loss to noise prediction models. `N` specifies its weight. **Not typically used for SDXL**. +* `--debiased_estimation_loss`: Calculates loss using Debiased Estimation. Similar purpose to Min-SNR but a different approach. +* `--loss_type=\"l1\"` / `\"l2\"` / `\"huber\"` / `\"smooth_l1\"`: Specifies the loss function. Default is `l2` (MSE). `huber` and `smooth_l1` are robust to outliers. +* `--huber_schedule=\"constant\"` / `\"exponential\"` / `\"snr\"`: Scheduling method when using `huber` or `smooth_l1` loss. `snr` is recommended. +* `--huber_c=C` / `--huber_scale=S`: Parameters for `huber` or `smooth_l1` loss. +* `--masked_loss`: Limits loss calculation area based on a mask image. Requires specifying mask images (black and white) in `conditioning_data_dir` in dataset settings. See [About Masked Loss](masked_loss_README.md) for details. + +### 1.10. Distributed Training and Other Training Related Options + +* `--seed=N`: Specifies the random seed. Set this to ensure training reproducibility. +* `--max_token_length=N` (`75`, `150`, `225`): Maximum token length processed by Text Encoders. For SDXL, typically `75` (default), `150`, or `225`. Longer lengths can handle more complex prompts but increase VRAM usage. +* `--clip_skip=N`: Uses the output from N layers skipped from the final layer of Text Encoders. **Not typically used for SDXL**. +* `--lowram` / `--highvram`: Options for memory usage optimization. `--lowram` is for environments like Colab where RAM < VRAM, `--highvram` is for environments with ample VRAM. +* `--persistent_data_loader_workers` / `--max_data_loader_n_workers=N`: Settings for DataLoader worker processes. Affects wait time between epochs and memory usage. +* `--config_file=""` / `--output_config`: Options to use/output a `.toml` file instead of command line arguments. +* **Accelerate/DeepSpeed related:** (`--ddp_timeout`, `--ddp_gradient_as_bucket_view`, `--ddp_static_graph`): Detailed settings for distributed training. Accelerate settings (`accelerate config`) are usually sufficient. DeepSpeed requires separate configuration. +* `--initial_epoch=` – Sets the initial epoch number. `1` means first epoch (same as not specifying). Note: `initial_epoch`/`initial_step` doesn't affect the lr scheduler, which means lr scheduler will start from 0 without `--resume`. +* `--initial_step=` – Sets the initial step number including all epochs. `0` means first step (same as not specifying). Overwrites `initial_epoch`. +* `--skip_until_initial_step` – Skips training until `initial_step` is reached. + +### 1.11. Console and Logging / コンソールとログ + +* `--console_log_level`: Sets the logging level for the console output. Choose from `DEBUG`, `INFO`, `WARNING`, `ERROR`, `CRITICAL`. +* `--console_log_file`: Redirects console logs to a specified file. +* `--console_log_simple`: Enables a simpler log format. + +### 1.12. Hugging Face Hub Integration / Hugging Face Hub 連携 + +* `--huggingface_repo_id`: The repository name on Hugging Face Hub to upload the model to (e.g., `your-username/your-model`). +* `--huggingface_repo_type`: The type of repository on Hugging Face Hub. Usually `model`. +* `--huggingface_path_in_repo`: The path within the repository to upload files to. +* `--huggingface_token`: Your Hugging Face Hub authentication token. +* `--huggingface_repo_visibility`: Sets the visibility of the repository (`public` or `private`). +* `--resume_from_huggingface`: Resumes training from a state saved on Hugging Face Hub. +* `--async_upload`: Enables asynchronous uploading of models to the Hub, preventing it from blocking the training process. +* `--save_n_epoch_ratio`: Saves the model at a certain ratio of total epochs. For example, `5` will save at least 5 checkpoints throughout the training. + +### 1.13. Advanced Attention Settings / 高度なAttention設定 + +* `--mem_eff_attn`: Use memory-efficient attention mechanism. This is an older implementation and `sdpa` or `xformers` are generally recommended. +* `--xformers`: Use xformers library for memory-efficient attention. Requires `pip install xformers`. + +### 1.14. Advanced LR Scheduler Settings / 高度な学習率スケジューラ設定 + +* `--lr_scheduler_type`: Specifies a custom scheduler module. +* `--lr_scheduler_args`: Provides additional arguments to the custom scheduler (e.g., `"T_max=100"`). +* `--lr_decay_steps`: Sets the number of steps for the learning rate to decay. +* `--lr_scheduler_timescale`: The timescale for the inverse square root scheduler. +* `--lr_scheduler_min_lr_ratio`: Sets the minimum learning rate as a ratio of the initial learning rate for certain schedulers. + +### 1.15. Differential Learning with LoRA / LoRAの差分学習 + +This technique involves merging a pre-trained LoRA into the base model before starting a new training session. This is useful for fine-tuning an existing LoRA or for learning the 'difference' from it. + +* `--base_weights`: Path to one or more LoRA weight files to be merged into the base model before training begins. +* `--base_weights_multiplier`: A multiplier for the weights of the LoRA specified by `--base_weights`. You can specify multiple values if you provide multiple weights. + +### 1.16. Other Miscellaneous Options / その他のオプション + +* `--tokenizer_cache_dir`: Specifies a directory to cache the tokenizer, which is useful for offline training. +* `--scale_weight_norms`: Scales the weight norms of the LoRA modules. This can help prevent overfitting by controlling the magnitude of the weights. A value of `1.0` is a good starting point. +* `--disable_mmap_load_safetensors`: Disables memory-mapped loading for `.safetensors` files. This can speed up model loading in some environments like WSL. + +## 2. Other Tips / その他のTips + + +* **VRAM Usage:** SDXL LoRA training requires a lot of VRAM. Even with 24GB VRAM, you might run out of memory depending on settings. Reduce VRAM usage with these settings: + * `--mixed_precision=\"bf16\"` or `\"fp16\"` (essential) + * `--gradient_checkpointing` (strongly recommended) + * `--cache_latents` / `--cache_text_encoder_outputs` (highly effective, with limitations) + * `--optimizer_type=\"AdamW8bit\"` or `\"Adafactor\"` + * Increase `--gradient_accumulation_steps` (reduce batch size) + * `--full_fp16` / `--full_bf16` (be mindful of stability) + * `--fp8_base` / `--fp8_base_unet` (experimental) + * `--fused_backward_pass` (Adafactor only, experimental) +* **Learning Rate:** Appropriate learning rates for SDXL LoRA depend on the dataset and `network_dim`/`alpha`. Starting around `1e-4` ~ `4e-5` (U-Net), `1e-5` ~ `2e-5` (Text Encoders) is common. +* **Training Time:** Training takes time due to high-resolution data and the size of the SDXL model. Using caching features and appropriate hardware is important. +* **Troubleshooting:** + * **NaN Loss:** Learning rate might be too high, mixed precision settings incorrect (e.g., `--no_half_vae` not specified with `fp16`), or dataset issues. + * **Out of Memory (OOM):** Try the VRAM reduction measures listed above. + * **Training not progressing:** Learning rate might be too low, optimizer/scheduler settings incorrect, or dataset issues. + +## 3. Conclusion / おわりに + +`sdxl_train_network.py` offers many options to customize SDXL LoRA training. Refer to `--help`, other documents and the source code for further details. + +
+日本語 + +# 高度な設定: SDXL LoRA学習スクリプト `sdxl_train_network.py` 詳細ガイド + +このドキュメントでは、`sd-scripts` リポジトリに含まれる `sdxl_train_network.py` を使用した、SDXL (Stable Diffusion XL) モデルに対する LoRA (Low-Rank Adaptation) モデル学習の高度な設定オプションについて解説します。 + +基本的な使い方については、以下のドキュメントを参照してください。 + +* [LoRA学習スクリプト `train_network.py` の使い方](train_network.md) +* [SDXL LoRA学習スクリプト `sdxl_train_network.py` の使い方](sdxl_train_network.md) + +このガイドは、基本的なLoRA学習の経験があり、より詳細な設定や高度な機能を試したい熟練した利用者を対象としています。 + +**前提条件:** + +* `sd-scripts` リポジトリのクローンと Python 環境のセットアップが完了していること。 +* 学習用データセットの準備と設定(`.toml`ファイル)が完了していること。([データセット設定ガイド](link/to/dataset/config/doc)参照) +* 基本的なLoRA学習のコマンドライン実行経験があること。 + +## 1. コマンドライン引数 詳細解説 + +`sdxl_train_network.py` は `train_network.py` の機能を継承しつつ、SDXL特有の機能を追加しています。ここでは、SDXL LoRA学習に関連する主要なコマンドライン引数について、機能別に分類して詳細に解説します。 + +基本的な引数については、[LoRA学習スクリプト `train_network.py` の使い方](train_network.md#31-主要なコマンドライン引数) および [SDXL LoRA学習スクリプト `sdxl_train_network.py` の使い方](sdxl_train_network.md#31-主要なコマンドライン引数(差分)) を参照してください。 + +### 1.1. モデル読み込み関連 + +* `--pretrained_model_name_or_path="<モデルパス>"` **[必須]** + * 学習のベースとなる **SDXLモデル** を指定します。Hugging Face HubのモデルID、ローカルのDiffusers形式モデルディレクトリ、または`.safetensors`ファイルを指定できます。 + * 詳細は[基本ガイド](sdxl_train_network.md#モデル関連)を参照してください。 +* `--vae=""` + * オプションで、学習に使用するVAEを指定します。SDXLモデルに含まれるVAE以外を使用する場合に指定します。`.ckpt`または`.safetensors`ファイルを指定できます。 +* `--no_half_vae` + * 混合精度(`fp16`/`bf16`)使用時でもVAEを`float32`で動作させます。SDXLのVAEは`float16`で不安定になることがあるため、`fp16`指定時には有効にすることが推奨されます。`bf16`では通常不要です。 +* `--fp8_base` / `--fp8_base_unet` + * **実験的機能:** ベースモデル(U-Net, Text Encoder)またはU-NetのみをFP8で読み込み、VRAM使用量を削減します。PyTorch 2.1以上が必要です。詳細は TODO 後でドキュメントを追加 の関連セクションを参照してください (SD3の説明ですがSDXLにも適用されます)。 + +### 1.2. データセット設定関連 + +* `--dataset_config="<設定ファイルのパス>"` + * データセットの設定を記述した`.toml`ファイルを指定します。SDXLでは高解像度データとバケツ機能(`.toml` で `enable_bucket = true` を指定)の利用が一般的です。 + * `.toml`ファイルの書き方の詳細は[データセット設定ガイド](link/to/dataset/config/doc)を参照してください。 + * アスペクト比バケツの解像度ステップ(`bucket_reso_steps`)は、SDXLでは32の倍数とする必要があります。 + +### 1.3. 出力・保存関連 + +基本的なオプションは `train_network.py` と共通です。 + +* `--output_dir="<出力先ディレクトリ>"` **[必須]** +* `--output_name="<出力ファイル名>"` **[必須]** +* `--save_model_as="safetensors"` (推奨), `ckpt`, `pt`, `diffusers`, `diffusers_safetensors` +* `--save_precision="fp16"`, `"bf16"`, `"float"` + * モデルの保存精度を指定します。未指定時は学習時の精度(`fp16`, `bf16`等)で保存されます。 +* `--save_every_n_epochs=N` / `--save_every_n_steps=N` + * Nエポック/ステップごとにモデルを保存します。 +* `--save_last_n_epochs=M` / `--save_last_n_steps=M` + * エポック/ステップごとに保存する際、最新のM個のみを保持し、古いものは削除します。 +* `--save_state` / `--save_state_on_train_end` + * モデル保存時/学習終了時に、Optimizerの状態などを含む学習状態(`state`)を保存します。`--resume`オプションでの学習再開に必要です。 +* `--save_last_n_epochs_state=M` / `--save_last_n_steps_state=M` + * `state`の保存数をM個に制限します。`--save_last_n_epochs/steps`の指定を上書きします。 +* `--no_metadata` + * 出力モデルにメタデータを保存しません。 +* `--save_state_to_huggingface` / `--huggingface_repo_id` など + * Hugging Face Hubへのモデルやstateのアップロード関連オプション。詳細は TODO ドキュメントを追加 を参照してください。 + +### 1.4. ネットワークパラメータ (LoRA) + +基本的なオプションは `train_network.py` と共通です。 + +* `--network_module=networks.lora` **[必須]** +* `--network_dim=N` **[必須]** + * LoRAのランク (次元数) を指定します。SDXLでは32や64などが試されることが多いですが、データセットや目的に応じて調整が必要です。 +* `--network_alpha=M` + * LoRAのアルファ値。`network_dim`の半分程度、または`network_dim`と同じ値などが一般的です。デフォルトは1。 +* `--network_dropout=P` + * LoRAモジュール内のドロップアウト率 (0.0~1.0)。過学習抑制の効果が期待できます。デフォルトはNone (ドロップアウトなし)。 +* `--network_args ...` + * ネットワークモジュールへの追加引数を `key=value` 形式で指定します。LoRAでは以下の高度な設定が可能です。 + * **階層別 (Block-wise) 次元数/アルファ:** + * U-Netの各ブロックごとに異なる`dim`と`alpha`を指定できます。これにより、特定の層の影響を強めたり弱めたりする調整が可能です。 + * `block_dims`: U-NetのLinear層およびConv2d 1x1層に対するブロックごとのdimをカンマ区切りで指定します (SDXLでは23個の数値)。 + * `block_alphas`: 上記に対応するalpha値をカンマ区切りで指定します。 + * `conv_block_dims`: U-NetのConv2d 3x3層に対するブロックごとのdimをカンマ区切りで指定します。 + * `conv_block_alphas`: 上記に対応するalpha値をカンマ区切りで指定します。 + * 指定しないブロックは `--network_dim`/`--network_alpha` または `--conv_dim`/`--conv_alpha` (存在する場合) の値が使用されます。 + * 詳細は[LoRA の階層別学習率](train_network.md#lora-の階層別学習率) (train\_network.md内、SDXLでも同様に適用可能) や実装 ([lora.py](lora.py)) を参照してください。 + * **LoRA+:** + * `loraplus_lr_ratio=R`: LoRAの上向き重み(UP)の学習率を、下向き重み(DOWN)の学習率のR倍にします。学習速度の向上が期待できます。論文推奨は16。 + * `loraplus_unet_lr_ratio=RU`: U-Net部分のLoRA+学習率比を個別に指定します。 + * `loraplus_text_encoder_lr_ratio=RT`: Text Encoder部分のLoRA+学習率比を個別に指定します。(`--text_encoder_lr1`, `--text_encoder_lr2`で指定した学習率に乗算されます) + * 詳細は[README](../README.md#jan-17-2025--2025-01-17-version-090)や実装 ([lora.py](lora.py)) を参照してください。 +* `--network_train_unet_only` + * U-NetのLoRAモジュールのみを学習します。Text Encoderの学習を行わない場合に指定します。`--cache_text_encoder_outputs` を使用する場合は必須です。 +* `--network_train_text_encoder_only` + * Text EncoderのLoRAモジュールのみを学習します。U-Netの学習を行わない場合に指定します。 +* `--network_weights="<重みファイル>"` + * 学習済みのLoRA重みを読み込んで学習を開始します。ファインチューニングや学習再開に使用します。`--resume` との違いは、このオプションはLoRAモジュールの重みのみを読み込み、`--resume` はOptimizerの状態や学習ステップ数なども復元します。 +* `--dim_from_weights` + * `--network_weights` で指定した重みファイルからLoRAの次元数 (`dim`) を自動的に読み込みます。`--network_dim` の指定は不要になります。 + +### 1.5. 学習パラメータ + +* `--learning_rate=LR` + * 全体の学習率。各モジュール(`unet_lr`, `text_encoder_lr1`, `text_encoder_lr2`)のデフォルト値となります。`1e-3` や `1e-4` などが試されることが多いです。 +* `--unet_lr=LR_U` + * U-Net部分のLoRAモジュールの学習率。 +* `--text_encoder_lr1=LR_TE1` + * Text Encoder 1 (OpenCLIP ViT-G/14) のLoRAモジュールの学習率。通常、U-Netより小さい値 (例: `1e-5`, `2e-5`) が推奨されます。 +* `--text_encoder_lr2=LR_TE2` + * Text Encoder 2 (CLIP ViT-L/14) のLoRAモジュールの学習率。通常、U-Netより小さい値 (例: `1e-5`, `2e-5`) が推奨されます。 +* `--optimizer_type="..."` + * 使用するOptimizerを指定します。`AdamW8bit` (省メモリ、一般的), `Adafactor` (さらに省メモリ、SDXLフルモデル学習で実績あり), `Lion`, `DAdaptation`, `Prodigy`などが選択可能です。各Optimizerには追加の引数が必要な場合があります (`--optimizer_args`参照)。 + * `AdamW8bit` や `PagedAdamW8bit` (要 `bitsandbytes`) が一般的です。 + * `Adafactor` はメモリ効率が良いですが、設定がやや複雑です (相対ステップ(`relative_step=True`)推奨、学習率スケジューラは`adafactor`推奨)。 + * `DAdaptation`, `Prodigy` は学習率の自動調整機能がありますが、LoRA+との併用はできません。学習率は`1.0`程度を指定します。 + * 詳細は[train\_util.py](train_util.py)の`get_optimizer`関数を参照してください。 +* `--optimizer_args ...` + * Optimizerへの追加引数を `key=value` 形式で指定します (例: `"weight_decay=0.01"` `"betas=0.9,0.999"`). +* `--lr_scheduler="..."` + * 学習率スケジューラを指定します。`constant` (変化なし), `cosine` (コサインカーブ), `linear` (線形減衰), `constant_with_warmup` (ウォームアップ付き定数), `cosine_with_restarts` など。`constant` や `cosine` 、 `constant_with_warmup` がよく使われます。 + * スケジューラによっては追加の引数が必要です (`--lr_scheduler_args`参照)。 + * `DAdaptation` や `Prodigy` などの自己学習率調整機能付きOptimizerを使用する場合、スケジューラは不要です (`constant` を指定)。 +* `--lr_warmup_steps=N` + * 学習率スケジューラのウォームアップステップ数。学習開始時に学習率を徐々に上げていく期間です。N < 1 の場合は全ステップ数に対する割合と解釈されます。 +* `--lr_scheduler_num_cycles=N` / `--lr_scheduler_power=P` + * 特定のスケジューラ (`cosine_with_restarts`, `polynomial`) のためのパラメータ。 +* `--max_train_steps=N` / `--max_train_epochs=N` + * 学習の総ステップ数またはエポック数を指定します。エポック指定が優先されます。 +* `--mixed_precision="bf16"` / `"fp16"` / `"no"` + * 混合精度学習の設定。SDXLでは `bf16` (対応GPUの場合) または `fp16` の使用が強く推奨されます。VRAM使用量を削減し、学習速度を向上させます。 +* `--full_fp16` / `--full_bf16` + * 勾配計算も含めて完全に半精度/bf16で行います。VRAM使用量をさらに削減できますが、学習の安定性に影響する可能性があります。VRAMがどうしても足りない場合に使用します。 +* `--gradient_accumulation_steps=N` + * 勾配をNステップ分蓄積してからOptimizerを更新します。実質的なバッチサイズを `train_batch_size * N` に増やし、少ないVRAMで大きなバッチサイズ相当の効果を得られます。デフォルトは1。 +* `--max_grad_norm=N` + * 勾配クリッピングの閾値。勾配のノルムがNを超える場合にクリッピングします。デフォルトは1.0。`0`で無効。 +* `--gradient_checkpointing` + * メモリ使用量を大幅に削減しますが、学習速度は若干低下します。SDXLではメモリ消費が大きいため、有効にすることが推奨されます。 +* `--fused_backward_pass` + * **実験的機能:** 勾配計算とOptimizerのステップを融合し、VRAM使用量を削減します。SDXLで利用可能です。現在 `Adafactor` Optimizerのみ対応。Gradient Accumulationとは併用できません。 +* `--resume=""` + * `--save_state`で保存された学習状態から学習を再開します。Optimizerの状態や学習ステップ数などが復元されます。 + +### 1.6. キャッシュ機能関連 + +SDXLは計算コストが高いため、キャッシュ機能が効果的です。 + +* `--cache_latents` + * VAEの出力(Latent)をメモリにキャッシュします。VAEの計算を省略でき、VRAM使用量を削減し、学習を高速化します。**注意:** 画像に対するAugmentation (`color_aug`, `flip_aug`, `random_crop` 等) は無効になります。 +* `--cache_latents_to_disk` + * `--cache_latents` と併用し、キャッシュ先をディスクにします。大量のデータセットや複数回の学習で特に有効です。初回実行時にディスクにキャッシュが生成され、2回目以降はそれを読み込みます。 +* `--cache_text_encoder_outputs` + * Text Encoderの出力をメモリにキャッシュします。Text Encoderの計算を省略でき、VRAM使用量を削減し、学習を高速化します。**注意:** キャプションに対するAugmentation (`shuffle_caption`, `caption_dropout_rate` 等) は無効になります。**また、このオプションを使用する場合、Text EncoderのLoRAモジュールは学習できません (`--network_train_unet_only` の指定が必須です)。** +* `--cache_text_encoder_outputs_to_disk` + * `--cache_text_encoder_outputs` と併用し、キャッシュ先をディスクにします。 +* `--skip_cache_check` + * キャッシュファイルの内容の検証をスキップします。ファイルの存在確認は行われ、存在しない場合はキャッシュが生成されます。デバッグ等で意図的に再キャッシュしたい場合を除き、通常は指定不要です。 + +### 1.7. サンプル画像生成関連 + +基本的なオプションは `train_network.py` と共通です。 + +* `--sample_every_n_steps=N` / `--sample_every_n_epochs=N` + * Nステップ/エポックごとにサンプル画像を生成します。 +* `--sample_at_first` + * 学習開始前にサンプル画像を生成します。 +* `--sample_prompts="<プロンプトファイル>"` + * サンプル画像生成に使用するプロンプトを記述したファイル (`.txt`, `.toml`, `.json`) を指定します。書式は[gen\_img\_diffusers.py](gen_img_diffusers.py)に準じます。詳細は[ドキュメント](gen_img_README-ja.md)を参照してください。 +* `--sample_sampler="..."` + * サンプル画像生成時のサンプラー(スケジューラ)を指定します。`euler_a`, `dpm++_2m_karras` などが一般的です。選択肢は `--help` を参照してください。 + +### 1.8. Logging & Tracking 関連 + +* `--logging_dir="<ログディレクトリ>"` + * TensorBoardなどのログを出力するディレクトリを指定します。指定しない場合、ログは出力されません。 +* `--log_with="tensorboard"` / `"wandb"` / `"all"` + * 使用するログツールを指定します。`wandb`を使用する場合、`pip install wandb`が必要です。 +* `--log_prefix="<プレフィックス>"` + * `logging_dir` 内に作成されるサブディレクトリ名の接頭辞を指定します。 +* `--wandb_api_key=""` / `--wandb_run_name="<実行名>"` + * Weights & Biases (wandb) 使用時のオプション。 +* `--log_tracker_name` / `--log_tracker_config` + * 高度なトラッカー設定用オプション。通常は指定不要。 +* `--log_config` + * 学習開始時に、使用された学習設定(一部の機密情報を除く)をログに出力します。再現性の確保に役立ちます。 + +### 1.9. 正則化・高度な学習テクニック関連 + +* `--noise_offset=N` + * ノイズオフセットを有効にし、その値を指定します。画像の明るさやコントラストの偏りを改善する効果が期待できます。SDXLのベースモデルはこの値で学習されているため、有効にすることが推奨されます (例: 0.0357)。元々の技術解説は[こちら](https://www.crosslabs.org/blog/diffusion-with-offset-noise)。 +* `--noise_offset_random_strength` + * ノイズオフセットの強度を0から指定値の間でランダムに変動させます。 +* `--adaptive_noise_scale=N` + * Latentの平均絶対値に応じてノイズオフセットを調整します。`--noise_offset`と併用します。 +* `--multires_noise_iterations=N` / `--multires_noise_discount=D` + * 複数解像度ノイズを有効にします。異なる周波数成分のノイズを加えることで、ディテールの再現性を向上させる効果が期待できます。イテレーション回数N (6-10程度) と割引率D (0.3程度) を指定します。技術解説は[こちら](https://wandb.ai/johnowhitaker/multires_noise/reports/Multi-Resolution-Noise-for-Diffusion-Model-Training--VmlldzozNjYyOTU2)。 +* `--ip_noise_gamma=G` / `--ip_noise_gamma_random_strength` + * Input Perturbation Noiseを有効にします。入力(Latent)に微小なノイズを加えて正則化を行います。Gamma値 (0.1程度) を指定します。`random_strength`で強度をランダム化できます。 +* `--min_snr_gamma=N` + * Min-SNR Weighting Strategy を適用します。学習初期のノイズが大きいタイムステップでのLossの重みを調整し、学習を安定させます。`N=5` などが使用されます。 +* `--scale_v_pred_loss_like_noise_pred` + * v-predictionモデルにおいて、vの予測ロスをノイズ予測ロスと同様のスケールに調整します。SDXLはv-predictionではないため、**通常は使用しません**。 +* `--v_pred_like_loss=N` + * ノイズ予測モデルにv予測ライクなロスを追加します。`N`でその重みを指定します。SDXLでは**通常は使用しません**。 +* `--debiased_estimation_loss` + * Debiased EstimationによるLoss計算を行います。Min-SNRと類似の目的を持ちますが、異なるアプローチです。 +* `--loss_type="l1"` / `"l2"` / `"huber"` / `"smooth_l1"` + * 損失関数を指定します。デフォルトは`l2` (MSE)。`huber`や`smooth_l1`は外れ値に頑健な損失関数です。 +* `--huber_schedule="constant"` / `"exponential"` / `"snr"` + * `huber`または`smooth_l1`損失使用時のスケジューリング方法。`snr`が推奨されています。 +* `--huber_c=C` / `--huber_scale=S` + * `huber`または`smooth_l1`損失のパラメータ。 +* `--masked_loss` + * マスク画像に基づいてLoss計算領域を限定します。データセット設定で`conditioning_data_dir`にマスク画像(白黒)を指定する必要があります。詳細は[マスクロスについて](masked_loss_README.md)を参照してください。 + +### 1.10. 分散学習、その他学習関連 + +* `--seed=N` + * 乱数シードを指定します。学習の再現性を確保したい場合に設定します。 +* `--max_token_length=N` (`75`, `150`, `225`) + * Text Encoderが処理するトークンの最大長。SDXLでは通常`75` (デフォルト) または `150`, `225`。長くするとより複雑なプロンプトを扱えますが、VRAM使用量が増加します。 +* `--clip_skip=N` + * Text Encoderの最終層からN層スキップした層の出力を使用します。SDXLでは**通常使用しません**。 +* `--lowram` / `--highvram` + * メモリ使用量の最適化に関するオプション。`--lowram`はColabなどRAM < VRAM環境向け、`--highvram`はVRAM潤沢な環境向け。 +* `--persistent_data_loader_workers` / `--max_data_loader_n_workers=N` + * DataLoaderのワーカプロセスに関する設定。エポック間の待ち時間やメモリ使用量に影響します。 +* `--config_file="<設定ファイル>"` / `--output_config` + * コマンドライン引数の代わりに`.toml`ファイルを使用/出力するオプション。 +* **Accelerate/DeepSpeed関連:** (`--ddp_timeout`, `--ddp_gradient_as_bucket_view`, `--ddp_static_graph`) + * 分散学習時の詳細設定。通常はAccelerateの設定 (`accelerate config`) で十分です。DeepSpeedを使用する場合は、別途設定が必要です。 +* `--initial_epoch=` – 開始エポック番号を設定します。`1`で最初のエポック(未指定時と同じ)。注意:`initial_epoch`/`initial_step`はlr schedulerに影響しないため、`--resume`しない場合はlr schedulerは0から始まります。 +* `--initial_step=` – 全エポックを含む開始ステップ番号を設定します。`0`で最初のステップ(未指定時と同じ)。`initial_epoch`を上書きします。 +* `--skip_until_initial_step` – `initial_step`に到達するまで学習をスキップします。 + +### 1.11. コンソールとログ + +* `--console_log_level`: コンソール出力のログレベルを設定します。`DEBUG`, `INFO`, `WARNING`, `ERROR`, `CRITICAL`から選択します。 +* `--console_log_file`: コンソールのログを指定されたファイルに出力します。 +* `--console_log_simple`: よりシンプルなログフォーマットを有効にします。 + +### 1.12. Hugging Face Hub 連携 + +* `--huggingface_repo_id`: モデルをアップロードするHugging Face Hubのリポジトリ名 (例: `your-username/your-model`)。 +* `--huggingface_repo_type`: Hugging Face Hubのリポジトリの種類。通常は`model`です。 +* `--huggingface_path_in_repo`: リポジトリ内でファイルをアップロードするパス。 +* `--huggingface_token`: Hugging Face Hubの認証トークン。 +* `--huggingface_repo_visibility`: リポジトリの公開設定 (`public`または`private`)。 +* `--resume_from_huggingface`: Hugging Face Hubに保存された状態から学習を再開します。 +* `--async_upload`: Hubへのモデルの非同期アップロードを有効にし、学習プロセスをブロックしないようにします。 +* `--save_n_epoch_ratio`: 総エポック数に対する特定の比率でモデルを保存します。例えば`5`を指定すると、学習全体で少なくとも5つのチェックポイントが保存されます。 + +### 1.13. 高度なAttention設定 + +* `--mem_eff_attn`: メモリ効率の良いAttentionメカニズムを使用します。これは古い実装であり、一般的には`sdpa`や`xformers`の使用が推奨されます。 +* `--xformers`: メモリ効率の良いAttentionのためにxformersライブラリを使用します。`pip install xformers`が必要です。 + +### 1.14. 高度な学習率スケジューラ設定 + +* `--lr_scheduler_type`: カスタムスケジューラモジュールを指定します。 +* `--lr_scheduler_args`: カスタムスケジューラに追加の引数を渡します (例: `"T_max=100"`)。 +* `--lr_decay_steps`: 学習率が減衰するステップ数を設定します。 +* `--lr_scheduler_timescale`: 逆平方根スケジューラのタイムスケール。 +* `--lr_scheduler_min_lr_ratio`: 特定のスケジューラについて、初期学習率に対する最小学習率の比率を設定します。 + +### 1.15. LoRAの差分学習 + +既存の学習済みLoRAをベースモデルにマージしてから、新たな学習を開始する手法です。既存LoRAのファインチューニングや、差分を学習させたい場合に有効です。 + +* `--base_weights`: 学習開始前にベースモデルにマージするLoRAの重みファイルを1つ以上指定します。 +* `--base_weights_multiplier`: `--base_weights`で指定したLoRAの重みの倍率。複数指定も可能です。 + +### 1.16. その他のオプション + +* `--tokenizer_cache_dir`: オフラインでの学習に便利なように、tokenizerをキャッシュするディレクトリを指定します。 +* `--scale_weight_norms`: LoRAモジュールの重みのノルムをスケーリングします。重みの大きさを制御することで過学習を防ぐ助けになります。`1.0`が良い出発点です。 +* `--disable_mmap_load_safetensors`: `.safetensors`ファイルのメモリマップドローディングを無効にします。WSLなどの一部環境でモデルの読み込みを高速化できます。 + +## 2. その他のTips + + +* **VRAM使用量:** SDXL LoRA学習は多くのVRAMを必要とします。24GB VRAMでも設定によってはメモリ不足になることがあります。以下の設定でVRAM使用量を削減できます。 + * `--mixed_precision="bf16"` または `"fp16"` (必須級) + * `--gradient_checkpointing` (強く推奨) + * `--cache_latents` / `--cache_text_encoder_outputs` (効果大、制約あり) + * `--optimizer_type="AdamW8bit"` または `"Adafactor"` + * `--gradient_accumulation_steps` の値を増やす (バッチサイズを小さくする) + * `--full_fp16` / `--full_bf16` (安定性に注意) + * `--fp8_base` / `--fp8_base_unet` (実験的) + * `--fused_backward_pass` (Adafactor限定、実験的) +* **学習率:** SDXL LoRAの適切な学習率はデータセットや`network_dim`/`alpha`に依存します。`1e-4` ~ `4e-5` (U-Net), `1e-5` ~ `2e-5` (Text Encoders) あたりから試すのが一般的です。 +* **学習時間:** 高解像度データとSDXLモデルのサイズのため、学習には時間がかかります。キャッシュ機能や適切なハードウェアの利用が重要です。 +* **トラブルシューティング:** + * **NaN Loss:** 学習率が高すぎる、混合精度の設定が不適切 (`fp16`時の`--no_half_vae`未指定など)、データセットの問題などが考えられます。 + * **VRAM不足 (OOM):** 上記のVRAM削減策を試してください。 + * **学習が進まない:** 学習率が低すぎる、Optimizer/Schedulerの設定が不適切、データセットの問題などが考えられます。 + +## 3. おわりに + +`sdxl_train_network.py` は非常に多くのオプションを提供しており、SDXL LoRA学習の様々な側面をカスタマイズできます。このドキュメントが、より高度な設定やチューニングを行う際の助けとなれば幸いです。 + +不明な点や詳細については、各スクリプトの `--help` オプションや、リポジトリ内の他のドキュメント、実装コード自体を参照してください。 + +
diff --git a/docs/train_textual_inversion.md b/docs/train_textual_inversion.md new file mode 100644 index 000000000..b7c69eb7b --- /dev/null +++ b/docs/train_textual_inversion.md @@ -0,0 +1,291 @@ +# How to use Textual Inversion training scripts / Textual Inversion学習スクリプトの使い方 + +This document explains how to train Textual Inversion embeddings using the `train_textual_inversion.py` and `sdxl_train_textual_inversion.py` scripts included in the `sd-scripts` repository. + +
+日本語 +このドキュメントでは、`sd-scripts` リポジトリに含まれる `train_textual_inversion.py` および `sdxl_train_textual_inversion.py` を使用してTextual Inversionの埋め込みを学習する方法について解説します。 +
+ +## 1. Introduction / はじめに + +[Textual Inversion](https://textual-inversion.github.io/) is a technique that teaches Stable Diffusion new concepts by learning new token embeddings. Instead of fine-tuning the entire model, it only optimizes the text encoder's token embeddings, making it a lightweight approach to teaching the model specific characters, objects, or artistic styles. + +**Available Scripts:** +- `train_textual_inversion.py`: For Stable Diffusion v1.x and v2.x models +- `sdxl_train_textual_inversion.py`: For Stable Diffusion XL models + +**Prerequisites:** +* The `sd-scripts` repository has been cloned and the Python environment has been set up. +* The training dataset has been prepared. For dataset preparation, please refer to the [Dataset Configuration Guide](config_README-en.md). + +
+日本語 + +[Textual Inversion](https://textual-inversion.github.io/) は、新しいトークンの埋め込みを学習することで、Stable Diffusionに新しい概念を教える技術です。モデル全体をファインチューニングする代わりに、テキストエンコーダのトークン埋め込みのみを最適化するため、特定のキャラクター、オブジェクト、芸術的スタイルをモデルに教えるための軽量なアプローチです。 + +**利用可能なスクリプト:** +- `train_textual_inversion.py`: Stable Diffusion v1.xおよびv2.xモデル用 +- `sdxl_train_textual_inversion.py`: Stable Diffusion XLモデル用 + +**前提条件:** +* `sd-scripts` リポジトリのクローンとPython環境のセットアップが完了していること。 +* 学習用データセットの準備が完了していること。データセットの準備については[データセット設定ガイド](config_README-en.md)を参照してください。 +
+ +## 2. Basic Usage / 基本的な使用方法 + +### 2.1. For Stable Diffusion v1.x/v2.x Models / Stable Diffusion v1.x/v2.xモデル用 + +```bash +accelerate launch --num_cpu_threads_per_process 1 train_textual_inversion.py \ + --pretrained_model_name_or_path="path/to/model.safetensors" \ + --dataset_config="dataset_config.toml" \ + --output_dir="output" \ + --output_name="my_textual_inversion" \ + --save_model_as="safetensors" \ + --token_string="mychar" \ + --init_word="girl" \ + --num_vectors_per_token=4 \ + --max_train_steps=1600 \ + --learning_rate=1e-6 \ + --optimizer_type="AdamW8bit" \ + --mixed_precision="fp16" \ + --cache_latents \ + --sdpa +``` + +### 2.2. For SDXL Models / SDXLモデル用 + +```bash +accelerate launch --num_cpu_threads_per_process 1 sdxl_train_textual_inversion.py \ + --pretrained_model_name_or_path="path/to/sdxl_model.safetensors" \ + --dataset_config="dataset_config.toml" \ + --output_dir="output" \ + --output_name="my_sdxl_textual_inversion" \ + --save_model_as="safetensors" \ + --token_string="mychar" \ + --init_word="girl" \ + --num_vectors_per_token=4 \ + --max_train_steps=1600 \ + --learning_rate=1e-6 \ + --optimizer_type="AdamW8bit" \ + --mixed_precision="fp16" \ + --cache_latents \ + --sdpa +``` + +
+日本語 +上記のコマンドは実際には1行で書く必要がありますが、見やすさのために改行しています(LinuxやMacでは行末に `\` を追加することで改行できます)。Windowsの場合は、改行せずに1行で書くか、`^` を行末に追加してください。 +
+ +## 3. Key Command-Line Arguments / 主要なコマンドライン引数 + +### 3.1. Textual Inversion Specific Arguments / Textual Inversion固有の引数 + +#### Core Parameters / コアパラメータ + +* `--token_string="mychar"` **[Required]** + * Specifies the token string used in training. This must not exist in the tokenizer's vocabulary. In your training prompts, include this token string (e.g., if token_string is "mychar", use prompts like "mychar 1girl"). + * 学習時に使用されるトークン文字列を指定します。tokenizerの語彙に存在しない文字である必要があります。学習時のプロンプトには、このトークン文字列を含める必要があります(例:token_stringが"mychar"なら、"mychar 1girl"のようなプロンプトを使用)。 + +* `--init_word="girl"` + * Specifies the word to use for initializing the embedding vector. Choose a word that is conceptually close to what you want to teach. Must be a single token. + * 埋め込みベクトルの初期化に使用する単語を指定します。教えたい概念に近い単語を選ぶとよいでしょう。単一のトークンである必要があります。 + +* `--num_vectors_per_token=4` + * Specifies how many embedding vectors to use for this token. More vectors provide greater expressiveness but consume more tokens from the 77-token limit. + * このトークンに使用する埋め込みベクトルの数を指定します。多いほど表現力が増しますが、77トークン制限からより多くのトークンを消費します。 + +* `--weights="path/to/existing_embedding.safetensors"` + * Loads pre-trained embeddings to continue training from. Optional parameter for transfer learning. + * 既存の埋め込みを読み込んで、そこから追加で学習します。転移学習のオプションパラメータです。 + +#### Template Options / テンプレートオプション + +* `--use_object_template` + * Ignores captions and uses predefined object templates (e.g., "a photo of a {}"). Same as the original implementation. + * キャプションを無視して、事前定義された物体用テンプレート(例:"a photo of a {}")を使用します。公式実装と同じです。 + +* `--use_style_template` + * Ignores captions and uses predefined style templates (e.g., "a painting in the style of {}"). Same as the original implementation. + * キャプションを無視して、事前定義されたスタイル用テンプレート(例:"a painting in the style of {}")を使用します。公式実装と同じです。 + +### 3.2. Model and Dataset Arguments / モデル・データセット引数 + +For common model and dataset arguments, please refer to [LoRA Training Guide](train_network.md#31-main-command-line-arguments--主要なコマンドライン引数). The following arguments work the same way: + +* `--pretrained_model_name_or_path` +* `--dataset_config` +* `--v2`, `--v_parameterization` +* `--resolution` +* `--cache_latents`, `--vae_batch_size` +* `--enable_bucket`, `--min_bucket_reso`, `--max_bucket_reso` + +
+日本語 +一般的なモデル・データセット引数については、[LoRA学習ガイド](train_network.md#31-main-command-line-arguments--主要なコマンドライン引数)を参照してください。以下の引数は同様に動作します: + +* `--pretrained_model_name_or_path` +* `--dataset_config` +* `--v2`, `--v_parameterization` +* `--resolution` +* `--cache_latents`, `--vae_batch_size` +* `--enable_bucket`, `--min_bucket_reso`, `--max_bucket_reso` +
+ +### 3.3. Training Parameters / 学習パラメータ + +For training parameters, please refer to [LoRA Training Guide](train_network.md#31-main-command-line-arguments--主要なコマンドライン引数). Textual Inversion typically uses these settings: + +* `--learning_rate=1e-6`: Lower learning rates are often used compared to LoRA training +* `--max_train_steps=1600`: Fewer steps are usually sufficient +* `--optimizer_type="AdamW8bit"`: Memory-efficient optimizer +* `--mixed_precision="fp16"`: Reduces memory usage + +**Note:** Textual Inversion has lower memory requirements compared to full model fine-tuning, so you can often use larger batch sizes. + +
+日本語 +学習パラメータについては、[LoRA学習ガイド](train_network.md#31-main-command-line-arguments--主要なコマンドライン引数)を参照してください。Textual Inversionでは通常以下の設定を使用します: + +* `--learning_rate=1e-6`: LoRA学習と比べて低い学習率がよく使用されます +* `--max_train_steps=1600`: より少ないステップで十分な場合が多いです +* `--optimizer_type="AdamW8bit"`: メモリ効率的なオプティマイザ +* `--mixed_precision="fp16"`: メモリ使用量を削減 + +**注意:** Textual Inversionはモデル全体のファインチューニングと比べてメモリ要件が低いため、多くの場合、より大きなバッチサイズを使用できます。 +
+ +## 4. Dataset Preparation / データセット準備 + +### 4.1. Dataset Configuration / データセット設定 + +Create a TOML configuration file as described in the [Dataset Configuration Guide](config_README-en.md). Here's an example for Textual Inversion: + +```toml +[general] +shuffle_caption = false +caption_extension = ".txt" +keep_tokens = 1 + +[[datasets]] +resolution = 512 # 1024 for SDXL +batch_size = 4 # Can use larger values than LoRA training +enable_bucket = true + + [[datasets.subsets]] + image_dir = "path/to/images" + caption_extension = ".txt" + num_repeats = 10 +``` + +### 4.2. Caption Guidelines / キャプションガイドライン + +**Important:** Your captions must include the token string you specified. For example: + +* If `--token_string="mychar"`, captions should be like: "mychar, 1girl, blonde hair, blue eyes" +* The token string can appear anywhere in the caption, but including it is essential + +You can verify that your token string is being recognized by using `--debug_dataset`, which will show token IDs. Look for tokens with IDs ≥ 49408 (these are the new custom tokens). + +
+日本語 + +**重要:** キャプションには指定したトークン文字列を含める必要があります。例: + +* `--token_string="mychar"` の場合、キャプションは "mychar, 1girl, blonde hair, blue eyes" のようにします +* トークン文字列はキャプション内のどこに配置しても構いませんが、含めることが必須です + +`--debug_dataset` を使用してトークン文字列が認識されているかを確認できます。これによりトークンIDが表示されます。ID ≥ 49408 のトークン(これらは新しいカスタムトークン)を探してください。 +
+ +## 5. Advanced Configuration / 高度な設定 + +### 5.1. Multiple Token Vectors / 複数トークンベクトル + +When using `--num_vectors_per_token` > 1, the system creates additional token variations: +- `--token_string="mychar"` with `--num_vectors_per_token=4` creates: "mychar", "mychar1", "mychar2", "mychar3" + +For generation, you can use either the base token or all tokens together. + +### 5.2. Memory Optimization / メモリ最適化 + +* Use `--cache_latents` to cache VAE outputs and reduce VRAM usage +* Use `--gradient_checkpointing` for additional memory savings +* For SDXL, use `--cache_text_encoder_outputs` to cache text encoder outputs +* Consider using `--mixed_precision="bf16"` on newer GPUs (RTX 30 series and later) + +### 5.3. Training Tips / 学習のコツ + +* **Learning Rate:** Start with 1e-6 and adjust based on results. Lower rates often work better than LoRA training. +* **Steps:** 1000-2000 steps are usually sufficient, but this varies by dataset size and complexity. +* **Batch Size:** Textual Inversion can handle larger batch sizes than full fine-tuning due to lower memory requirements. +* **Templates:** Use `--use_object_template` for characters/objects, `--use_style_template` for artistic styles. + +
+日本語 + +* **学習率:** 1e-6から始めて、結果に基づいて調整してください。LoRA学習よりも低い率がよく機能します。 +* **ステップ数:** 通常1000-2000ステップで十分ですが、データセットのサイズと複雑さによって異なります。 +* **バッチサイズ:** メモリ要件が低いため、Textual Inversionは完全なファインチューニングよりも大きなバッチサイズを処理できます。 +* **テンプレート:** キャラクター/オブジェクトには `--use_object_template`、芸術的スタイルには `--use_style_template` を使用してください。 +
+ +## 6. Usage After Training / 学習後の使用方法 + +The trained Textual Inversion embeddings can be used in: + +* **Automatic1111 WebUI:** Place the `.safetensors` file in the `embeddings` folder +* **ComfyUI:** Use the embedding file with appropriate nodes +* **Other Diffusers-based applications:** Load using the embedding path + +In your prompts, simply use the token string you trained (e.g., "mychar") and the model will use the learned embedding. + +
+日本語 + +学習したTextual Inversionの埋め込みは以下で使用できます: + +* **Automatic1111 WebUI:** `.safetensors` ファイルを `embeddings` フォルダに配置 +* **ComfyUI:** 適切なノードで埋め込みファイルを使用 +* **その他のDiffusersベースアプリケーション:** 埋め込みパスを使用して読み込み + +プロンプトでは、学習したトークン文字列(例:"mychar")を単純に使用するだけで、モデルが学習した埋め込みを使用します。 +
+ +## 7. Troubleshooting / トラブルシューティング + +### Common Issues / よくある問題 + +1. **Token string already exists in tokenizer** + * Use a unique string that doesn't exist in the model's vocabulary + * Try adding numbers or special characters (e.g., "mychar123") + +2. **No improvement after training** + * Ensure your captions include the token string + * Try adjusting the learning rate (lower values like 5e-7) + * Increase the number of training steps + + * Use `--cache_latents` + +
+日本語 + +1. **トークン文字列がtokenizerに既に存在する** + * モデルの語彙に存在しない固有の文字列を使用してください + * 数字や特殊文字を追加してみてください(例:"mychar123") + +2. **学習後に改善が見られない** + * キャプションにトークン文字列が含まれていることを確認してください + * 学習率を調整してみてください(5e-7のような低い値) + * 学習ステップ数を増やしてください + +3. **メモリ不足エラー** + * データセット設定でバッチサイズを減らしてください + * `--gradient_checkpointing` を使用してください + * `--cache_latents` を使用してください +
+ +For additional training options and advanced configurations, please refer to the [LoRA Training Guide](train_network.md) as many parameters are shared between training methods. \ No newline at end of file diff --git a/docs/validation.md b/docs/validation.md new file mode 100644 index 000000000..7f6a008c2 --- /dev/null +++ b/docs/validation.md @@ -0,0 +1,261 @@ +# Validation Loss + +Validation loss is a crucial metric for monitoring the training process of a model. It helps you assess how well your model is generalizing to data it hasn't seen during training, which is essential for preventing overfitting. By periodically evaluating the model on a separate validation dataset, you can gain insights into its performance and make more informed decisions about when to stop training or adjust hyperparameters. + +This feature provides a stable and reliable validation loss metric by ensuring the validation process is deterministic. + +
+日本語 + +Validation loss(検証損失)は、モデルの学習過程を監視するための重要な指標です。モデルが学習中に見ていないデータに対してどの程度汎化できているかを評価するのに役立ち、過学習を防ぐために不可欠です。個別の検証データセットで定期的にモデルを評価することで、そのパフォーマンスに関する洞察を得て、学習をいつ停止するか、またはハイパーパラメータを調整するかについて、より多くの情報に基づいた決定を下すことができます。 + +この機能は、検証プロセスが決定論的であることを保証することにより、安定して信頼性の高い検証損失指標を提供します。 + +
+ +## How It Works + +When validation is enabled, a portion of your dataset is set aside specifically for this purpose. The script then runs a validation step at regular intervals, calculating the loss on this validation data. + +To ensure that the validation loss is a reliable indicator of model performance, the process is deterministic. This means that for every validation run, the same random seed is used for noise generation and timestep selection. This consistency ensures that any fluctuations in the validation loss are due to changes in the model's weights, not random variations in the validation process itself. + +The average loss across all validation steps is then logged, providing a single, clear metric to track. + +For more technical details, please refer to the original pull request: [PR #1903](https://github.com/kohya-ss/sd-scripts/pull/1903). + +
+日本語 + +検証が有効になると、データセットの一部がこの目的のために特別に確保されます。スクリプトは定期的な間隔で検証ステップを実行し、この検証データに対する損失を計算します。 + +検証損失がモデルのパフォーマンスの信頼できる指標であることを保証するために、プロセスは決定論的です。つまり、すべての検証実行で、ノイズ生成とタイムステップ選択に同じランダムシードが使用されます。この一貫性により、検証損失の変動が、検証プロセス自体のランダムな変動ではなく、モデルの重みの変化によるものであることが保証されます。 + +すべての検証ステップにわたる平均損失がログに記録され、追跡するための単一の明確な指標が提供されます。 + +より技術的な詳細については、元のプルリクエストを参照してください: [PR #1903](https://github.com/kohya-ss/sd-scripts/pull/1903). + +
+ +## How to Use + +### Enabling Validation + +There are two primary ways to enable validation: + +1. **Using a Dataset Config File (Recommended)**: You can specify a validation set directly within your dataset `.toml` file. This method offers the most control, allowing you to designate entire directories as validation sets or split a percentage of a specific subset for validation. + + To use a whole directory for validation, add a subset and set `validation_split = 1.0`. + + **Example: Separate Validation Set** + ```toml + [[datasets]] + # ... training subset ... + [[datasets.subsets]] + image_dir = "path/to/train_images" + # ... other settings ... + + # Validation subset + [[datasets.subsets]] + image_dir = "path/to/validation_images" + validation_split = 1.0 # Use this entire subset for validation + ``` + + To use a fraction of a subset for validation, set `validation_split` to a value between 0.0 and 1.0. + + **Example: Splitting a Subset** + ```toml + [[datasets]] + # ... dataset settings ... + [[datasets.subsets]] + image_dir = "path/to/images" + validation_split = 0.1 # Use 10% of this subset for validation + ``` + +2. **Using a Command-Line Argument**: For a simpler setup, you can use the `--validation_split` argument. This will take a random percentage of your *entire* training dataset for validation. This method is ignored if `validation_split` is defined in your dataset config file. + + **Example Command:** + ```bash + accelerate launch train_network.py ... --validation_split 0.1 + ``` + This command will use 10% of the total training data for validation. + +
+日本語 + +### 検証を有効にする + +検証を有効にする主な方法は2つあります。 + +1. **データセット設定ファイルを使用する(推奨)**: データセットの`.toml`ファイル内で直接検証セットを指定できます。この方法は最も制御性が高く、ディレクトリ全体を検証セットとして指定したり、特定のサブセットのパーセンテージを検証用に分割したりすることができます。 + + ディレクトリ全体を検証に使用するには、サブセットを追加して`validation_split = 1.0`と設定します。 + + **例:個別の検証セット** + ```toml + [[datasets]] + # ... training subset ... + [[datasets.subsets]] + image_dir = "path/to/train_images" + # ... other settings ... + + # Validation subset + [[datasets.subsets]] + image_dir = "path/to/validation_images" + validation_split = 1.0 # このサブセット全体を検証に使用します + ``` + + サブセットの一部を検証に使用するには、`validation_split`を0.0から1.0の間の値に設定します。 + + **例:サブセットの分割** + ```toml + [[datasets]] + # ... dataset settings ... + [[datasets.subsets]] + image_dir = "path/to/images" + validation_split = 0.1 # このサブセットの10%を検証に使用します + ``` + +2. **コマンドライン引数を使用する**: より簡単な設定のために、`--validation_split`引数を使用できます。これにより、*全*学習データセットのランダムなパーセンテージが検証に使用されます。この方法は、データセット設定ファイルで`validation_split`が定義されている場合は無視されます。 + + **コマンド例:** + ```bash + accelerate launch train_network.py ... --validation_split 0.1 + ``` + このコマンドは、全学習データの10%を検証に使用します。 + +
+ +### Configuration Options + +| Argument | TOML Option | Description | +| --------------------------- | ------------------- | -------------------------------------------------------------------------------------------------------------------------------------- | +| `--validation_split` | `validation_split` | The fraction of the dataset to use for validation. The command-line argument applies globally, while the TOML option applies per-subset. The TOML setting takes precedence. | +| `--validate_every_n_steps` | | Run validation every N steps. | +| `--validate_every_n_epochs` | | Run validation every N epochs. If not specified, validation runs once per epoch by default. | +| `--max_validation_steps` | | The maximum number of batches to use for a single validation run. If not set, the entire validation dataset is used. | +| `--validation_seed` | `validation_seed` | A specific seed for the validation dataloader shuffling. If not set in the TOML file, the main training `--seed` is used. | + +
+日本語 + +### 設定オプション + +| 引数 | TOMLオプション | 説明 | +| --------------------------- | ------------------- | -------------------------------------------------------------------------------------------------------------------------------------- | +| `--validation_split` | `validation_split` | 検証に使用するデータセットの割合。コマンドライン引数は全体に適用され、TOMLオプションはサブセットごとに適用されます。TOML設定が優先されます。 | +| `--validate_every_n_steps` | | Nステップごとに検証を実行します。 | +| `--validate_every_n_epochs` | | Nエポックごとに検証を実行します。指定しない場合、デフォルトでエポックごとに1回検証が実行されます。 | +| `--max_validation_steps` | | 1回の検証実行に使用するバッチの最大数。設定しない場合、検証データセット全体が使用されます。 | +| `--validation_seed` | `validation_seed` | 検証データローダーのシャッフル用の特定のシード。TOMLファイルで設定されていない場合、メインの学習`--seed`が使用されます。 | + +
+ +### Viewing the Results + +The validation loss is logged to your tracking tool of choice (TensorBoard or Weights & Biases). Look for the metric `loss/validation` to monitor the performance. + +
+日本語 + +### 結果の表示 + +検証損失は、選択した追跡ツール(TensorBoardまたはWeights & Biases)に記録されます。パフォーマンスを監視するには、`loss/validation`という指標を探してください。 + +
+ +### Practical Example + +Here is a complete example of how to run a LoRA training with validation enabled: + +**1. Prepare your `dataset_config.toml`:** + +```toml +[general] +shuffle_caption = true +keep_tokens = 1 + +[[datasets]] +resolution = "1024,1024" +batch_size = 2 + + [[datasets.subsets]] + image_dir = 'path/to/your_images' + caption_extension = '.txt' + num_repeats = 10 + + [[datasets.subsets]] + image_dir = 'path/to/your_validation_images' + caption_extension = '.txt' + validation_split = 1.0 # Use this entire subset for validation +``` + +**2. Run the training command:** + +```bash +accelerate launch sdxl_train_network.py \ + --pretrained_model_name_or_path="sd_xl_base_1.0.safetensors" \ + --dataset_config="dataset_config.toml" \ + --output_dir="output" \ + --output_name="my_lora" \ + --network_module=networks.lora \ + --network_dim=32 \ + --network_alpha=16 \ + --save_every_n_epochs=1 \ + --learning_rate=1e-4 \ + --optimizer_type="AdamW8bit" \ + --mixed_precision="bf16" \ + --logging_dir=logs +``` + +The validation loss will be calculated once per epoch and saved to the `logs` directory, which you can view with TensorBoard. + +
+日本語 + +### 実践的な例 + +検証を有効にしてLoRAの学習を実行する完全な例を次に示します。 + +**1. `dataset_config.toml`を準備します:** + +```toml +[general] +shuffle_caption = true +keep_tokens = 1 + +[[datasets]] +resolution = "1024,1024" +batch_size = 2 + + [[datasets.subsets]] + image_dir = 'path/to/your_images' + caption_extension = '.txt' + num_repeats = 10 + + [[datasets.subsets]] + image_dir = 'path/to/your_validation_images' + caption_extension = '.txt' + validation_split = 1.0 # このサブセット全体を検証に使用します +``` + +**2. 学習コマンドを実行します:** + +```bash +accelerate launch sdxl_train_network.py \ + --pretrained_model_name_or_path="sd_xl_base_1.0.safetensors" \ + --dataset_config="dataset_config.toml" \ + --output_dir="output" \ + --output_name="my_lora" \ + --network_module=networks.lora \ + --network_dim=32 \ + --network_alpha=16 \ + --save_every_n_epochs=1 \ + --learning_rate=1e-4 \ + --optimizer_type="AdamW8bit" \ + --mixed_precision="bf16" \ + --logging_dir=logs +``` + +検証損失はエポックごとに1回計算され、`logs`ディレクトリに保存されます。これはTensorBoardで表示できます。 + +
diff --git a/fine_tune.py b/fine_tune.py index c79f97d25..ffbbbb09f 100644 --- a/fine_tune.py +++ b/fine_tune.py @@ -10,7 +10,7 @@ from tqdm import tqdm import torch -from library import deepspeed_utils +from library import deepspeed_utils, strategy_base from library.device_utils import init_ipex, clean_memory_on_device init_ipex() @@ -27,6 +27,7 @@ import library.train_util as train_util import library.config_util as config_util +import library.sai_model_spec as sai_model_spec from library.config_util import ( ConfigSanitizer, BlueprintGenerator, @@ -39,6 +40,7 @@ scale_v_prediction_loss_like_noise_prediction, apply_debiased_estimation, ) +import library.strategy_sd as strategy_sd def train(args): @@ -52,7 +54,15 @@ def train(args): if args.seed is not None: set_seed(args.seed) # 乱数系列を初期化する - tokenizer = train_util.load_tokenizer(args) + tokenize_strategy = strategy_sd.SdTokenizeStrategy(args.v2, args.max_token_length, args.tokenizer_cache_dir) + strategy_base.TokenizeStrategy.set_strategy(tokenize_strategy) + + # prepare caching strategy: this must be set before preparing dataset. because dataset may use this strategy for initialization. + if cache_latents: + latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy( + False, args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check + ) + strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy) # データセットを準備する if args.dataset_class is None: @@ -81,10 +91,11 @@ def train(args): ] } - blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer) - train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) + blueprint = blueprint_generator.generate(user_config, args) + train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) else: - train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizer) + train_dataset_group = train_util.load_arbitrary_dataset(args) + val_dataset_group = None current_epoch = Value("i", 0) current_step = Value("i", 0) @@ -167,8 +178,9 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): vae.to(accelerator.device, dtype=vae_dtype) vae.requires_grad_(False) vae.eval() - with torch.no_grad(): - train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process) + + train_dataset_group.new_cache_latents(vae, accelerator) + vae.to("cpu") clean_memory_on_device(accelerator.device) @@ -194,6 +206,9 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): else: text_encoder.eval() + text_encoding_strategy = strategy_sd.SdTextEncodingStrategy(args.clip_skip) + strategy_base.TextEncodingStrategy.set_strategy(text_encoding_strategy) + if not cache_latents: vae.requires_grad_(False) vae.eval() @@ -216,7 +231,11 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): accelerator.print("prepare optimizer, data loader etc.") _, _, optimizer = train_util.get_optimizer(args, trainable_params=trainable_params) - # dataloaderを準備する + # prepare dataloader + # strategies are set here because they cannot be referenced in another process. Copy them with the dataset + # some strategies can be None + train_dataset_group.set_current_strategies() + # DataLoaderのプロセス数:0 は persistent_workers が使えないので注意 n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers train_dataloader = torch.utils.data.DataLoader( @@ -319,7 +338,12 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): ) # For --sample_at_first - train_util.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizer, text_encoder, unet) + train_util.sample_images( + accelerator, args, 0, global_step, accelerator.device, vae, tokenize_strategy.tokenizer, text_encoder, unet + ) + if len(accelerator.trackers) > 0: + # log empty object to commit the sample images to wandb + accelerator.log({}, step=0) loss_recorder = train_util.LossRecorder() for epoch in range(num_train_epochs): @@ -344,25 +368,21 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): with torch.set_grad_enabled(args.train_text_encoder): # Get the text embedding for conditioning if args.weighted_captions: - encoder_hidden_states = get_weighted_text_embeddings( - tokenizer, - text_encoder, - batch["captions"], - accelerator.device, - args.max_token_length // 75 if args.max_token_length else 1, - clip_skip=args.clip_skip, - ) + input_ids_list, weights_list = tokenize_strategy.tokenize_with_weights(batch["captions"]) + encoder_hidden_states = text_encoding_strategy.encode_tokens_with_weights( + tokenize_strategy, [text_encoder], input_ids_list, weights_list + )[0] else: - input_ids = batch["input_ids"].to(accelerator.device) - encoder_hidden_states = train_util.get_hidden_states( - args, input_ids, tokenizer, text_encoder, None if not args.full_fp16 else weight_dtype - ) + input_ids = batch["input_ids_list"][0].to(accelerator.device) + encoder_hidden_states = text_encoding_strategy.encode_tokens( + tokenize_strategy, [text_encoder], [input_ids] + )[0] + if args.full_fp16: + encoder_hidden_states = encoder_hidden_states.to(weight_dtype) # Sample noise, sample a random timestep for each image, and add noise to the latents, # with noise offset and/or multires noise if specified - noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps( - args, noise_scheduler, latents - ) + noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents) # Predict the noise residual with accelerator.autocast(): @@ -374,11 +394,10 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): else: target = noise + huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler) if args.min_snr_gamma or args.scale_v_pred_loss_like_noise_pred or args.debiased_estimation_loss: # do not mean over batch dimension for snr weight or scale v-pred loss - loss = train_util.conditional_loss( - noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c - ) + loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c) loss = loss.mean([1, 2, 3]) if args.min_snr_gamma: @@ -390,9 +409,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): loss = loss.mean() # mean over batch dimension else: - loss = train_util.conditional_loss( - noise_pred.float(), target.float(), reduction="mean", loss_type=args.loss_type, huber_c=huber_c - ) + loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "mean", huber_c) accelerator.backward(loss) if accelerator.sync_gradients and args.max_grad_norm != 0.0: @@ -411,7 +428,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): global_step += 1 train_util.sample_images( - accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet + accelerator, args, None, global_step, accelerator.device, vae, tokenize_strategy.tokenizer, text_encoder, unet ) # 指定ステップごとにモデルを保存 @@ -436,7 +453,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): ) current_loss = loss.detach().item() # 平均なのでbatch sizeは関係ないはず - if args.logging_dir is not None: + if len(accelerator.trackers) > 0: logs = {"loss": current_loss} train_util.append_lr_to_logs(logs, lr_scheduler, args.optimizer_type, including_unet=True) accelerator.log(logs, step=global_step) @@ -449,7 +466,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): if global_step >= args.max_train_steps: break - if args.logging_dir is not None: + if len(accelerator.trackers) > 0: logs = {"loss/epoch": loss_recorder.moving_average} accelerator.log(logs, step=epoch + 1) @@ -474,7 +491,9 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): vae, ) - train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet) + train_util.sample_images( + accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenize_strategy.tokenizer, text_encoder, unet + ) is_main_process = accelerator.is_main_process if is_main_process: @@ -501,6 +520,7 @@ def setup_parser() -> argparse.ArgumentParser: add_logging_arguments(parser) train_util.add_sd_models_arguments(parser) + sai_model_spec.add_model_spec_arguments(parser) train_util.add_dataset_arguments(parser, False, True, True) train_util.add_training_arguments(parser, False) deepspeed_utils.add_deepspeed_arguments(parser) diff --git a/finetune/tag_images_by_wd14_tagger.py b/finetune/tag_images_by_wd14_tagger.py index 3bf71fed1..07a6510e6 100644 --- a/finetune/tag_images_by_wd14_tagger.py +++ b/finetune/tag_images_by_wd14_tagger.py @@ -11,7 +11,7 @@ from tqdm import tqdm import library.train_util as train_util -from library.utils import setup_logging, pil_resize +from library.utils import setup_logging, resize_image setup_logging() import logging @@ -42,10 +42,7 @@ def preprocess_image(image): pad_t = pad_y // 2 image = np.pad(image, ((pad_t, pad_y - pad_t), (pad_l, pad_x - pad_l), (0, 0)), mode="constant", constant_values=255) - if size > IMAGE_SIZE: - image = cv2.resize(image, (IMAGE_SIZE, IMAGE_SIZE), cv2.INTER_AREA) - else: - image = pil_resize(image, (IMAGE_SIZE, IMAGE_SIZE)) + image = resize_image(image, image.shape[0], image.shape[1], IMAGE_SIZE, IMAGE_SIZE) image = image.astype(np.float32) return image diff --git a/flux_minimal_inference.py b/flux_minimal_inference.py new file mode 100644 index 000000000..0664b3c78 --- /dev/null +++ b/flux_minimal_inference.py @@ -0,0 +1,596 @@ +# Minimum Inference Code for FLUX + +import argparse +import datetime +import math +import os +import random +from typing import Callable, List, Optional +import einops +import numpy as np + +import torch +from tqdm import tqdm +from PIL import Image +import accelerate +from transformers import CLIPTextModel +from safetensors.torch import load_file + +from library import device_utils +from library.device_utils import init_ipex, get_preferred_device +from networks import oft_flux + +init_ipex() + + +from library.utils import setup_logging, str_to_dtype + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + +import networks.lora_flux as lora_flux +from library import flux_models, flux_utils, sd3_utils, strategy_flux + + +def time_shift(mu: float, sigma: float, t: torch.Tensor): + return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma) + + +def get_lin_function(x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15) -> Callable[[float], float]: + m = (y2 - y1) / (x2 - x1) + b = y1 - m * x1 + return lambda x: m * x + b + + +def get_schedule( + num_steps: int, + image_seq_len: int, + base_shift: float = 0.5, + max_shift: float = 1.15, + shift: bool = True, +) -> list[float]: + # extra step for zero + timesteps = torch.linspace(1, 0, num_steps + 1) + + # shifting the schedule to favor high timesteps for higher signal images + if shift: + # eastimate mu based on linear estimation between two points + mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len) + timesteps = time_shift(mu, 1.0, timesteps) + + return timesteps.tolist() + + +def denoise( + model: flux_models.Flux, + img: torch.Tensor, + img_ids: torch.Tensor, + txt: torch.Tensor, + txt_ids: torch.Tensor, + vec: torch.Tensor, + timesteps: list[float], + guidance: float = 4.0, + t5_attn_mask: Optional[torch.Tensor] = None, + neg_txt: Optional[torch.Tensor] = None, + neg_vec: Optional[torch.Tensor] = None, + neg_t5_attn_mask: Optional[torch.Tensor] = None, + cfg_scale: Optional[float] = None, +): + # prepare classifier free guidance + logger.info(f"guidance: {guidance}, cfg_scale: {cfg_scale}") + do_cfg = neg_txt is not None and (cfg_scale is not None and cfg_scale != 1.0) + + # this is ignored for schnell + guidance_vec = torch.full((img.shape[0] * (2 if do_cfg else 1),), guidance, device=img.device, dtype=img.dtype) + + if do_cfg: + print("Using classifier free guidance") + b_img_ids = torch.cat([img_ids, img_ids], dim=0) + b_txt_ids = torch.cat([txt_ids, txt_ids], dim=0) + b_txt = torch.cat([neg_txt, txt], dim=0) + b_vec = torch.cat([neg_vec, vec], dim=0) if neg_vec is not None else None + if t5_attn_mask is not None and neg_t5_attn_mask is not None: + b_t5_attn_mask = torch.cat([neg_t5_attn_mask, t5_attn_mask], dim=0) + else: + b_t5_attn_mask = None + else: + b_img_ids = img_ids + b_txt_ids = txt_ids + b_txt = txt + b_vec = vec + b_t5_attn_mask = t5_attn_mask + + for t_curr, t_prev in zip(tqdm(timesteps[:-1]), timesteps[1:]): + t_vec = torch.full((b_img_ids.shape[0],), t_curr, dtype=img.dtype, device=img.device) + + # classifier free guidance + if do_cfg: + b_img = torch.cat([img, img], dim=0) + else: + b_img = img + + y_input = b_vec + + mod_vectors = model.get_mod_vectors(timesteps=t_vec, guidance=guidance_vec, batch_size=b_img.shape[0]) + + pred = model( + img=b_img, + img_ids=b_img_ids, + txt=b_txt, + txt_ids=b_txt_ids, + y=y_input, + timesteps=t_vec, + guidance=guidance_vec, + txt_attention_mask=b_t5_attn_mask, + mod_vectors=mod_vectors, + ) + + # classifier free guidance + if do_cfg: + pred_uncond, pred = torch.chunk(pred, 2, dim=0) + pred = pred_uncond + cfg_scale * (pred - pred_uncond) + + img = img + (t_prev - t_curr) * pred + + return img + + +def do_sample( + accelerator: Optional[accelerate.Accelerator], + model: flux_models.Flux, + img: torch.Tensor, + img_ids: torch.Tensor, + l_pooled: Optional[torch.Tensor], + t5_out: torch.Tensor, + txt_ids: torch.Tensor, + num_steps: int, + guidance: float, + t5_attn_mask: Optional[torch.Tensor], + is_schnell: bool, + device: torch.device, + flux_dtype: torch.dtype, + neg_l_pooled: Optional[torch.Tensor] = None, + neg_t5_out: Optional[torch.Tensor] = None, + neg_t5_attn_mask: Optional[torch.Tensor] = None, + cfg_scale: Optional[float] = None, +): + logger.info(f"num_steps: {num_steps}") + timesteps = get_schedule(num_steps, img.shape[1], shift=not is_schnell) + + # denoise initial noise + if accelerator: + with accelerator.autocast(), torch.no_grad(): + x = denoise( + model, + img, + img_ids, + t5_out, + txt_ids, + l_pooled, + timesteps, + guidance, + t5_attn_mask, + neg_t5_out, + neg_l_pooled, + neg_t5_attn_mask, + cfg_scale, + ) + else: + with torch.autocast(device_type=device.type, dtype=flux_dtype), torch.no_grad(): + x = denoise( + model, + img, + img_ids, + t5_out, + txt_ids, + l_pooled, + timesteps, + guidance, + t5_attn_mask, + neg_t5_out, + neg_l_pooled, + neg_t5_attn_mask, + cfg_scale, + ) + + return x + + +def generate_image( + model, + clip_l: Optional[CLIPTextModel], + t5xxl, + ae, + prompt: str, + seed: Optional[int], + image_width: int, + image_height: int, + steps: Optional[int], + guidance: float, + negative_prompt: Optional[str], + cfg_scale: float, +): + seed = seed if seed is not None else random.randint(0, 2**32 - 1) + logger.info(f"Seed: {seed}") + + # make first noise with packed shape + # original: b,16,2*h//16,2*w//16, packed: b,h//16*w//16,16*2*2 + packed_latent_height, packed_latent_width = math.ceil(image_height / 16), math.ceil(image_width / 16) + noise_dtype = torch.float32 if is_fp8(dtype) else dtype + noise = torch.randn( + 1, + packed_latent_height * packed_latent_width, + 16 * 2 * 2, + device=device, + dtype=noise_dtype, + generator=torch.Generator(device=device).manual_seed(seed), + ) + + # prepare img and img ids + + # this is needed only for img2img + # img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2) + # if img.shape[0] == 1 and bs > 1: + # img = repeat(img, "1 ... -> bs ...", bs=bs) + + # txt2img only needs img_ids + img_ids = flux_utils.prepare_img_ids(1, packed_latent_height, packed_latent_width) + + # prepare fp8 models + if clip_l is not None and is_fp8(clip_l_dtype) and (not hasattr(clip_l, "fp8_prepared") or not clip_l.fp8_prepared): + logger.info(f"prepare CLIP-L for fp8: set to {clip_l_dtype}, set embeddings to {torch.bfloat16}") + clip_l.to(clip_l_dtype) # fp8 + clip_l.text_model.embeddings.to(dtype=torch.bfloat16) + clip_l.fp8_prepared = True + + if is_fp8(t5xxl_dtype) and (not hasattr(t5xxl, "fp8_prepared") or not t5xxl.fp8_prepared): + logger.info(f"prepare T5xxl for fp8: set to {t5xxl_dtype}") + + def prepare_fp8(text_encoder, target_dtype): + def forward_hook(module): + def forward(hidden_states): + hidden_gelu = module.act(module.wi_0(hidden_states)) + hidden_linear = module.wi_1(hidden_states) + hidden_states = hidden_gelu * hidden_linear + hidden_states = module.dropout(hidden_states) + + hidden_states = module.wo(hidden_states) + return hidden_states + + return forward + + for module in text_encoder.modules(): + if module.__class__.__name__ in ["T5LayerNorm", "Embedding"]: + # print("set", module.__class__.__name__, "to", target_dtype) + module.to(target_dtype) + if module.__class__.__name__ in ["T5DenseGatedActDense"]: + # print("set", module.__class__.__name__, "hooks") + module.forward = forward_hook(module) + + t5xxl.to(t5xxl_dtype) + prepare_fp8(t5xxl.encoder, torch.bfloat16) + t5xxl.fp8_prepared = True + + # prepare embeddings + logger.info("Encoding prompts...") + if clip_l is not None: + clip_l = clip_l.to(device) + t5xxl = t5xxl.to(device) + + def encode(prpt: str): + tokens_and_masks = tokenize_strategy.tokenize(prpt) + with torch.no_grad(): + if clip_l is not None: + if is_fp8(clip_l_dtype): + with accelerator.autocast(): + l_pooled, _, _, _ = encoding_strategy.encode_tokens(tokenize_strategy, [clip_l, None], tokens_and_masks) + else: + with torch.autocast(device_type=device.type, dtype=clip_l_dtype): + l_pooled, _, _, _ = encoding_strategy.encode_tokens(tokenize_strategy, [clip_l, None], tokens_and_masks) + else: + l_pooled = None + + if is_fp8(t5xxl_dtype): + with accelerator.autocast(): + _, t5_out, txt_ids, t5_attn_mask = encoding_strategy.encode_tokens( + tokenize_strategy, [clip_l, t5xxl], tokens_and_masks, args.apply_t5_attn_mask + ) + else: + with torch.autocast(device_type=device.type, dtype=t5xxl_dtype): + _, t5_out, txt_ids, t5_attn_mask = encoding_strategy.encode_tokens( + tokenize_strategy, [clip_l, t5xxl], tokens_and_masks, args.apply_t5_attn_mask + ) + return l_pooled, t5_out, txt_ids, t5_attn_mask + + l_pooled, t5_out, txt_ids, t5_attn_mask = encode(prompt) + if negative_prompt: + neg_l_pooled, neg_t5_out, _, neg_t5_attn_mask = encode(negative_prompt) + else: + neg_l_pooled, neg_t5_out, neg_t5_attn_mask = None, None, None + + # NaN check + if l_pooled is not None and torch.isnan(l_pooled).any(): + raise ValueError("NaN in l_pooled") + if torch.isnan(t5_out).any(): + raise ValueError("NaN in t5_out") + + if args.offload: + if clip_l is not None: + clip_l = clip_l.cpu() + t5xxl = t5xxl.cpu() + # del clip_l, t5xxl + device_utils.clean_memory() + + # generate image + logger.info("Generating image...") + model = model.to(device) + if steps is None: + steps = 4 if is_schnell else 50 + + img_ids = img_ids.to(device) + t5_attn_mask = t5_attn_mask.to(device) if args.apply_t5_attn_mask else None + neg_t5_attn_mask = neg_t5_attn_mask.to(device) if neg_t5_attn_mask is not None and args.apply_t5_attn_mask else None + + x = do_sample( + accelerator, + model, + noise, + img_ids, + l_pooled, + t5_out, + txt_ids, + steps, + guidance, + t5_attn_mask, + is_schnell, + device, + flux_dtype, + neg_l_pooled, + neg_t5_out, + neg_t5_attn_mask, + cfg_scale, + ) + if args.offload: + model = model.cpu() + # del model + device_utils.clean_memory() + + # unpack + x = x.float() + x = einops.rearrange(x, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=packed_latent_height, w=packed_latent_width, ph=2, pw=2) + + # decode + logger.info("Decoding image...") + ae = ae.to(device) + with torch.no_grad(): + if is_fp8(ae_dtype): + with accelerator.autocast(): + x = ae.decode(x) + else: + with torch.autocast(device_type=device.type, dtype=ae_dtype): + x = ae.decode(x) + if args.offload: + ae = ae.cpu() + + x = x.clamp(-1, 1) + x = x.permute(0, 2, 3, 1) + img = Image.fromarray((127.5 * (x + 1.0)).float().cpu().numpy().astype(np.uint8)[0]) + + # save image + output_dir = args.output_dir + os.makedirs(output_dir, exist_ok=True) + output_path = os.path.join(output_dir, f"{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}.png") + img.save(output_path) + + logger.info(f"Saved image to {output_path}") + + +if __name__ == "__main__": + target_height = 768 # 1024 + target_width = 1360 # 1024 + + # steps = 50 # 28 # 50 + # guidance_scale = 5 + # seed = 1 # None # 1 + + device = get_preferred_device() + + parser = argparse.ArgumentParser() + parser.add_argument("--ckpt_path", type=str, required=True) + parser.add_argument("--model_type", type=str, choices=["flux", "chroma"], default="flux", help="Model type to use") + parser.add_argument("--clip_l", type=str, required=False) + parser.add_argument("--t5xxl", type=str, required=False) + parser.add_argument("--ae", type=str, required=False) + parser.add_argument("--apply_t5_attn_mask", action="store_true") + parser.add_argument("--prompt", type=str, default="A photo of a cat") + parser.add_argument("--output_dir", type=str, default=".") + parser.add_argument("--dtype", type=str, default="bfloat16", help="base dtype") + parser.add_argument("--clip_l_dtype", type=str, default=None, help="dtype for clip_l") + parser.add_argument("--ae_dtype", type=str, default=None, help="dtype for ae") + parser.add_argument("--t5xxl_dtype", type=str, default=None, help="dtype for t5xxl") + parser.add_argument("--flux_dtype", type=str, default=None, help="dtype for flux") + parser.add_argument("--seed", type=int, default=None) + parser.add_argument("--steps", type=int, default=None, help="Number of steps. Default is 4 for schnell, 50 for dev") + parser.add_argument("--guidance", type=float, default=3.5) + parser.add_argument("--negative_prompt", type=str, default=None) + parser.add_argument("--cfg_scale", type=float, default=1.0) + parser.add_argument("--offload", action="store_true", help="Offload to CPU") + parser.add_argument( + "--lora_weights", + type=str, + nargs="*", + default=[], + help="LoRA weights, only supports networks.lora_flux and lora_oft, each argument is a `path;multiplier` (semi-colon separated)", + ) + parser.add_argument("--merge_lora_weights", action="store_true", help="Merge LoRA weights to model") + parser.add_argument("--width", type=int, default=target_width) + parser.add_argument("--height", type=int, default=target_height) + parser.add_argument("--interactive", action="store_true") + args = parser.parse_args() + + seed = args.seed + steps = args.steps + guidance_scale = args.guidance + + def is_fp8(dt): + return dt in [torch.float8_e4m3fn, torch.float8_e4m3fnuz, torch.float8_e5m2, torch.float8_e5m2fnuz] + + dtype = str_to_dtype(args.dtype) + clip_l_dtype = str_to_dtype(args.clip_l_dtype, dtype) + t5xxl_dtype = str_to_dtype(args.t5xxl_dtype, dtype) + ae_dtype = str_to_dtype(args.ae_dtype, dtype) + flux_dtype = str_to_dtype(args.flux_dtype, dtype) + + logger.info(f"Dtypes for clip_l, t5xxl, ae, flux: {clip_l_dtype}, {t5xxl_dtype}, {ae_dtype}, {flux_dtype}") + + loading_device = "cpu" if args.offload else device + + use_fp8 = [is_fp8(d) for d in [dtype, clip_l_dtype, t5xxl_dtype, ae_dtype, flux_dtype]] + if any(use_fp8): + accelerator = accelerate.Accelerator(mixed_precision="bf16") + else: + accelerator = None + + # load clip_l (skip for chroma model) + if args.model_type == "flux": + logger.info(f"Loading clip_l from {args.clip_l}...") + clip_l = flux_utils.load_clip_l(args.clip_l, clip_l_dtype, loading_device, disable_mmap=True) + clip_l.eval() + else: + clip_l = None + + logger.info(f"Loading t5xxl from {args.t5xxl}...") + t5xxl = flux_utils.load_t5xxl(args.t5xxl, t5xxl_dtype, loading_device, disable_mmap=True) + t5xxl.eval() + + # if is_fp8(clip_l_dtype): + # clip_l = accelerator.prepare(clip_l) + # if is_fp8(t5xxl_dtype): + # t5xxl = accelerator.prepare(t5xxl) + + # DiT + is_schnell, model = flux_utils.load_flow_model( + args.ckpt_path, None, loading_device, disable_mmap=True, model_type=args.model_type + ) + model.eval() + logger.info(f"Casting model to {flux_dtype}") + model.to(flux_dtype) # make sure model is dtype + # if is_fp8(flux_dtype): + # model = accelerator.prepare(model) + # if args.offload: + # model = model.to("cpu") + + t5xxl_max_length = 256 if is_schnell else 512 + tokenize_strategy = strategy_flux.FluxTokenizeStrategy(t5xxl_max_length) + encoding_strategy = strategy_flux.FluxTextEncodingStrategy() + + # AE + ae = flux_utils.load_ae(args.ae, ae_dtype, loading_device) + ae.eval() + # if is_fp8(ae_dtype): + # ae = accelerator.prepare(ae) + + # LoRA + lora_models: List[lora_flux.LoRANetwork] = [] + for weights_file in args.lora_weights: + if ";" in weights_file: + weights_file, multiplier = weights_file.split(";") + multiplier = float(multiplier) + else: + multiplier = 1.0 + + weights_sd = load_file(weights_file) + is_lora = is_oft = False + for key in weights_sd.keys(): + if key.startswith("lora"): + is_lora = True + if key.startswith("oft"): + is_oft = True + if is_lora or is_oft: + break + + module = lora_flux if is_lora else oft_flux + lora_model, _ = module.create_network_from_weights(multiplier, None, ae, [clip_l, t5xxl], model, weights_sd, True) + + if args.merge_lora_weights: + lora_model.merge_to([clip_l, t5xxl], model, weights_sd) + else: + lora_model.apply_to([clip_l, t5xxl], model) + info = lora_model.load_state_dict(weights_sd, strict=True) + logger.info(f"Loaded LoRA weights from {weights_file}: {info}") + lora_model.eval() + lora_model.to(device) + + lora_models.append(lora_model) + + if not args.interactive: + generate_image( + model, + clip_l, + t5xxl, + ae, + args.prompt, + args.seed, + args.width, + args.height, + args.steps, + args.guidance, + args.negative_prompt, + args.cfg_scale, + ) + else: + # loop for interactive + width = target_width + height = target_height + steps = None + guidance = args.guidance + cfg_scale = args.cfg_scale + + while True: + print( + "Enter prompt (empty to exit). Options: --w --h --s --d --g --m " + " --n , `-` for empty negative prompt --c " + ) + prompt = input() + if prompt == "": + break + + # parse options + options = prompt.split("--") + prompt = options[0].strip() + seed = None + negative_prompt = None + for opt in options[1:]: + try: + opt = opt.strip() + if opt.startswith("w"): + width = int(opt[1:].strip()) + elif opt.startswith("h"): + height = int(opt[1:].strip()) + elif opt.startswith("s"): + steps = int(opt[1:].strip()) + elif opt.startswith("d"): + seed = int(opt[1:].strip()) + elif opt.startswith("g"): + guidance = float(opt[1:].strip()) + elif opt.startswith("m"): + mutipliers = opt[1:].strip().split(",") + if len(mutipliers) != len(lora_models): + logger.error(f"Invalid number of multipliers, expected {len(lora_models)}") + continue + for i, lora_model in enumerate(lora_models): + lora_model.set_multiplier(float(mutipliers[i])) + elif opt.startswith("n"): + negative_prompt = opt[1:].strip() + if negative_prompt == "-": + negative_prompt = "" + elif opt.startswith("c"): + cfg_scale = float(opt[1:].strip()) + except ValueError as e: + logger.error(f"Invalid option: {opt}, {e}") + + generate_image(model, clip_l, t5xxl, ae, prompt, seed, width, height, steps, guidance, negative_prompt, cfg_scale) + + logger.info("Done!") diff --git a/flux_train.py b/flux_train.py new file mode 100644 index 000000000..4aa67220f --- /dev/null +++ b/flux_train.py @@ -0,0 +1,851 @@ +# training with captions + +# Swap blocks between CPU and GPU: +# This implementation is inspired by and based on the work of 2kpr. +# Many thanks to 2kpr for the original concept and implementation of memory-efficient offloading. +# The original idea has been adapted and extended to fit the current project's needs. + +# Key features: +# - CPU offloading during forward and backward passes +# - Use of fused optimizer and grad_hook for efficient gradient processing +# - Per-block fused optimizer instances + +import argparse +from concurrent.futures import ThreadPoolExecutor +import copy +import math +import os +from multiprocessing import Value +import time +from typing import List, Optional, Tuple, Union +import toml + +from tqdm import tqdm + +import torch +import torch.nn as nn +from library import utils +from library.device_utils import init_ipex, clean_memory_on_device + +init_ipex() + +from accelerate.utils import set_seed +from library import deepspeed_utils, flux_train_utils, flux_utils, strategy_base, strategy_flux, sai_model_spec +from library.sd3_train_utils import FlowMatchEulerDiscreteScheduler + +import library.train_util as train_util + +from library.utils import setup_logging, add_logging_arguments + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + +import library.config_util as config_util + +# import library.sdxl_train_util as sdxl_train_util +from library.config_util import ( + ConfigSanitizer, + BlueprintGenerator, +) +from library.custom_train_functions import apply_masked_loss, add_custom_train_arguments + + +def train(args): + train_util.verify_training_args(args) + train_util.prepare_dataset_args(args, True) + # sdxl_train_util.verify_sdxl_training_args(args) + deepspeed_utils.prepare_deepspeed_args(args) + setup_logging(args, reset=True) + + # temporary: backward compatibility for deprecated options. remove in the future + if not args.skip_cache_check: + args.skip_cache_check = args.skip_latents_validity_check + + # assert ( + # not args.weighted_captions + # ), "weighted_captions is not supported currently / weighted_captionsは現在サポートされていません" + if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs: + logger.warning( + "cache_text_encoder_outputs_to_disk is enabled, so cache_text_encoder_outputs is also enabled / cache_text_encoder_outputs_to_diskが有効になっているため、cache_text_encoder_outputsも有効になります" + ) + args.cache_text_encoder_outputs = True + + if args.cpu_offload_checkpointing and not args.gradient_checkpointing: + logger.warning( + "cpu_offload_checkpointing is enabled, so gradient_checkpointing is also enabled / cpu_offload_checkpointingが有効になっているため、gradient_checkpointingも有効になります" + ) + args.gradient_checkpointing = True + + assert ( + args.blocks_to_swap is None or args.blocks_to_swap == 0 + ) or not args.cpu_offload_checkpointing, ( + "blocks_to_swap is not supported with cpu_offload_checkpointing / blocks_to_swapはcpu_offload_checkpointingと併用できません" + ) + + cache_latents = args.cache_latents + use_dreambooth_method = args.in_json is None + + if args.seed is not None: + set_seed(args.seed) # 乱数系列を初期化する + + # prepare caching strategy: this must be set before preparing dataset. because dataset may use this strategy for initialization. + if args.cache_latents: + latents_caching_strategy = strategy_flux.FluxLatentsCachingStrategy( + args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check + ) + strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy) + + # データセットを準備する + if args.dataset_class is None: + blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, args.masked_loss, True)) + if args.dataset_config is not None: + logger.info(f"Load dataset config from {args.dataset_config}") + user_config = config_util.load_user_config(args.dataset_config) + ignored = ["train_data_dir", "in_json"] + if any(getattr(args, attr) is not None for attr in ignored): + logger.warning( + "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( + ", ".join(ignored) + ) + ) + else: + if use_dreambooth_method: + logger.info("Using DreamBooth method.") + user_config = { + "datasets": [ + { + "subsets": config_util.generate_dreambooth_subsets_config_by_subdirs( + args.train_data_dir, args.reg_data_dir + ) + } + ] + } + else: + logger.info("Training with captions.") + user_config = { + "datasets": [ + { + "subsets": [ + { + "image_dir": args.train_data_dir, + "metadata_file": args.in_json, + } + ] + } + ] + } + + blueprint = blueprint_generator.generate(user_config, args) + train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) + else: + train_dataset_group = train_util.load_arbitrary_dataset(args) + val_dataset_group = None + + current_epoch = Value("i", 0) + current_step = Value("i", 0) + ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None + collator = train_util.collator_class(current_epoch, current_step, ds_for_collator) + + train_dataset_group.verify_bucket_reso_steps(16) # TODO これでいいか確認 + + _, is_schnell, _, _ = flux_utils.analyze_checkpoint_state(args.pretrained_model_name_or_path) + if args.debug_dataset: + if args.cache_text_encoder_outputs: + strategy_base.TextEncoderOutputsCachingStrategy.set_strategy( + strategy_flux.FluxTextEncoderOutputsCachingStrategy( + args.cache_text_encoder_outputs_to_disk, args.text_encoder_batch_size, args.skip_cache_check, False + ) + ) + t5xxl_max_token_length = ( + args.t5xxl_max_token_length if args.t5xxl_max_token_length is not None else (256 if is_schnell else 512) + ) + strategy_base.TokenizeStrategy.set_strategy(strategy_flux.FluxTokenizeStrategy(t5xxl_max_token_length)) + + train_dataset_group.set_current_strategies() + train_util.debug_dataset(train_dataset_group, True) + return + if len(train_dataset_group) == 0: + logger.error( + "No data found. Please verify the metadata file and train_data_dir option. / 画像がありません。メタデータおよびtrain_data_dirオプションを確認してください。" + ) + return + + if cache_latents: + assert ( + train_dataset_group.is_latent_cacheable() + ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません" + + if args.cache_text_encoder_outputs: + assert ( + train_dataset_group.is_text_encoder_output_cacheable() + ), "when caching text encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / text encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません" + + # acceleratorを準備する + logger.info("prepare accelerator") + accelerator = train_util.prepare_accelerator(args) + + # mixed precisionに対応した型を用意しておき適宜castする + weight_dtype, save_dtype = train_util.prepare_dtype(args) + + # モデルを読み込む + + # load VAE for caching latents + ae = None + if cache_latents: + ae = flux_utils.load_ae(args.ae, weight_dtype, "cpu", args.disable_mmap_load_safetensors) + ae.to(accelerator.device, dtype=weight_dtype) + ae.requires_grad_(False) + ae.eval() + + train_dataset_group.new_cache_latents(ae, accelerator) + + ae.to("cpu") # if no sampling, vae can be deleted + clean_memory_on_device(accelerator.device) + + accelerator.wait_for_everyone() + + # prepare tokenize strategy + if args.t5xxl_max_token_length is None: + if is_schnell: + t5xxl_max_token_length = 256 + else: + t5xxl_max_token_length = 512 + else: + t5xxl_max_token_length = args.t5xxl_max_token_length + + flux_tokenize_strategy = strategy_flux.FluxTokenizeStrategy(t5xxl_max_token_length) + strategy_base.TokenizeStrategy.set_strategy(flux_tokenize_strategy) + + # load clip_l, t5xxl for caching text encoder outputs + clip_l = flux_utils.load_clip_l(args.clip_l, weight_dtype, "cpu", args.disable_mmap_load_safetensors) + t5xxl = flux_utils.load_t5xxl(args.t5xxl, weight_dtype, "cpu", args.disable_mmap_load_safetensors) + clip_l.eval() + t5xxl.eval() + clip_l.requires_grad_(False) + t5xxl.requires_grad_(False) + + text_encoding_strategy = strategy_flux.FluxTextEncodingStrategy(args.apply_t5_attn_mask) + strategy_base.TextEncodingStrategy.set_strategy(text_encoding_strategy) + + # cache text encoder outputs + sample_prompts_te_outputs = None + if args.cache_text_encoder_outputs: + # Text Encodes are eval and no grad here + clip_l.to(accelerator.device) + t5xxl.to(accelerator.device) + + text_encoder_caching_strategy = strategy_flux.FluxTextEncoderOutputsCachingStrategy( + args.cache_text_encoder_outputs_to_disk, args.text_encoder_batch_size, False, False, args.apply_t5_attn_mask + ) + strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(text_encoder_caching_strategy) + + with accelerator.autocast(): + train_dataset_group.new_cache_text_encoder_outputs([clip_l, t5xxl], accelerator) + + # cache sample prompt's embeddings to free text encoder's memory + if args.sample_prompts is not None: + logger.info(f"cache Text Encoder outputs for sample prompt: {args.sample_prompts}") + + text_encoding_strategy: strategy_flux.FluxTextEncodingStrategy = strategy_base.TextEncodingStrategy.get_strategy() + + prompts = train_util.load_prompts(args.sample_prompts) + sample_prompts_te_outputs = {} # key: prompt, value: text encoder outputs + with accelerator.autocast(), torch.no_grad(): + for prompt_dict in prompts: + for p in [prompt_dict.get("prompt", ""), prompt_dict.get("negative_prompt", "")]: + if p not in sample_prompts_te_outputs: + logger.info(f"cache Text Encoder outputs for prompt: {p}") + tokens_and_masks = flux_tokenize_strategy.tokenize(p) + sample_prompts_te_outputs[p] = text_encoding_strategy.encode_tokens( + flux_tokenize_strategy, [clip_l, t5xxl], tokens_and_masks, args.apply_t5_attn_mask + ) + + accelerator.wait_for_everyone() + + # now we can delete Text Encoders to free memory + clip_l = None + t5xxl = None + clean_memory_on_device(accelerator.device) + + # load FLUX + _, flux = flux_utils.load_flow_model( + args.pretrained_model_name_or_path, weight_dtype, "cpu", args.disable_mmap_load_safetensors, model_type="flux" + ) + + if args.gradient_checkpointing: + flux.enable_gradient_checkpointing(cpu_offload=args.cpu_offload_checkpointing) + + flux.requires_grad_(True) + + # block swap + + # backward compatibility + if args.blocks_to_swap is None: + blocks_to_swap = args.double_blocks_to_swap or 0 + if args.single_blocks_to_swap is not None: + blocks_to_swap += args.single_blocks_to_swap // 2 + if blocks_to_swap > 0: + logger.warning( + "double_blocks_to_swap and single_blocks_to_swap are deprecated. Use blocks_to_swap instead." + " / double_blocks_to_swapとsingle_blocks_to_swapは非推奨です。blocks_to_swapを使ってください。" + ) + logger.info( + f"double_blocks_to_swap={args.double_blocks_to_swap} and single_blocks_to_swap={args.single_blocks_to_swap} are converted to blocks_to_swap={blocks_to_swap}." + ) + args.blocks_to_swap = blocks_to_swap + del blocks_to_swap + + is_swapping_blocks = args.blocks_to_swap is not None and args.blocks_to_swap > 0 + if is_swapping_blocks: + # Swap blocks between CPU and GPU to reduce memory usage, in forward and backward passes. + # This idea is based on 2kpr's great work. Thank you! + logger.info(f"enable block swap: blocks_to_swap={args.blocks_to_swap}") + flux.enable_block_swap(args.blocks_to_swap, accelerator.device) + + if not cache_latents: + # load VAE here if not cached + ae = flux_utils.load_ae(args.ae, weight_dtype, "cpu") + ae.requires_grad_(False) + ae.eval() + ae.to(accelerator.device, dtype=weight_dtype) + + training_models = [] + params_to_optimize = [] + training_models.append(flux) + name_and_params = list(flux.named_parameters()) + # single param group for now + params_to_optimize.append({"params": [p for _, p in name_and_params], "lr": args.learning_rate}) + param_names = [[n for n, _ in name_and_params]] + + # calculate number of trainable parameters + n_params = 0 + for group in params_to_optimize: + for p in group["params"]: + n_params += p.numel() + + accelerator.print(f"number of trainable parameters: {n_params}") + + # 学習に必要なクラスを準備する + accelerator.print("prepare optimizer, data loader etc.") + + if args.blockwise_fused_optimizers: + # fused backward pass: https://pytorch.org/tutorials/intermediate/optimizer_step_in_backward_tutorial.html + # Instead of creating an optimizer for all parameters as in the tutorial, we create an optimizer for each block of parameters. + # This balances memory usage and management complexity. + + # split params into groups. currently different learning rates are not supported + grouped_params = [] + param_group = {} + for group in params_to_optimize: + named_parameters = list(flux.named_parameters()) + assert len(named_parameters) == len(group["params"]), "number of parameters does not match" + for p, np in zip(group["params"], named_parameters): + # determine target layer and block index for each parameter + block_type = "other" # double, single or other + if np[0].startswith("double_blocks"): + block_index = int(np[0].split(".")[1]) + block_type = "double" + elif np[0].startswith("single_blocks"): + block_index = int(np[0].split(".")[1]) + block_type = "single" + else: + block_index = -1 + + param_group_key = (block_type, block_index) + if param_group_key not in param_group: + param_group[param_group_key] = [] + param_group[param_group_key].append(p) + + block_types_and_indices = [] + for param_group_key, param_group in param_group.items(): + block_types_and_indices.append(param_group_key) + grouped_params.append({"params": param_group, "lr": args.learning_rate}) + + num_params = 0 + for p in param_group: + num_params += p.numel() + accelerator.print(f"block {param_group_key}: {num_params} parameters") + + # prepare optimizers for each group + optimizers = [] + for group in grouped_params: + _, _, optimizer = train_util.get_optimizer(args, trainable_params=[group]) + optimizers.append(optimizer) + optimizer = optimizers[0] # avoid error in the following code + + logger.info(f"using {len(optimizers)} optimizers for blockwise fused optimizers") + + if train_util.is_schedulefree_optimizer(optimizers[0], args): + raise ValueError("Schedule-free optimizer is not supported with blockwise fused optimizers") + optimizer_train_fn = lambda: None # dummy function + optimizer_eval_fn = lambda: None # dummy function + else: + _, _, optimizer = train_util.get_optimizer(args, trainable_params=params_to_optimize) + optimizer_train_fn, optimizer_eval_fn = train_util.get_optimizer_train_eval_fn(optimizer, args) + + # prepare dataloader + # strategies are set here because they cannot be referenced in another process. Copy them with the dataset + # some strategies can be None + train_dataset_group.set_current_strategies() + + # DataLoaderのプロセス数:0 は persistent_workers が使えないので注意 + n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers + train_dataloader = torch.utils.data.DataLoader( + train_dataset_group, + batch_size=1, + shuffle=True, + collate_fn=collator, + num_workers=n_workers, + persistent_workers=args.persistent_data_loader_workers, + ) + + # 学習ステップ数を計算する + if args.max_train_epochs is not None: + args.max_train_steps = args.max_train_epochs * math.ceil( + len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps + ) + accelerator.print( + f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}" + ) + + # データセット側にも学習ステップを送信 + train_dataset_group.set_max_train_steps(args.max_train_steps) + + # lr schedulerを用意する + if args.blockwise_fused_optimizers: + # prepare lr schedulers for each optimizer + lr_schedulers = [train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) for optimizer in optimizers] + lr_scheduler = lr_schedulers[0] # avoid error in the following code + else: + lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) + + # 実験的機能:勾配も含めたfp16/bf16学習を行う モデル全体をfp16/bf16にする + if args.full_fp16: + assert ( + args.mixed_precision == "fp16" + ), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。" + accelerator.print("enable full fp16 training.") + flux.to(weight_dtype) + if clip_l is not None: + clip_l.to(weight_dtype) + t5xxl.to(weight_dtype) # TODO check works with fp16 or not + elif args.full_bf16: + assert ( + args.mixed_precision == "bf16" + ), "full_bf16 requires mixed precision='bf16' / full_bf16を使う場合はmixed_precision='bf16'を指定してください。" + accelerator.print("enable full bf16 training.") + flux.to(weight_dtype) + if clip_l is not None: + clip_l.to(weight_dtype) + t5xxl.to(weight_dtype) + + # if we don't cache text encoder outputs, move them to device + if not args.cache_text_encoder_outputs: + clip_l.to(accelerator.device) + t5xxl.to(accelerator.device) + + clean_memory_on_device(accelerator.device) + + if args.deepspeed: + ds_model = deepspeed_utils.prepare_deepspeed_model(args, mmdit=flux) + # most of ZeRO stage uses optimizer partitioning, so we have to prepare optimizer and ds_model at the same time. # pull/1139#issuecomment-1986790007 + ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + ds_model, optimizer, train_dataloader, lr_scheduler + ) + training_models = [ds_model] + + else: + # accelerator does some magic + # if we doesn't swap blocks, we can move the model to device + flux = accelerator.prepare(flux, device_placement=[not is_swapping_blocks]) + if is_swapping_blocks: + accelerator.unwrap_model(flux).move_to_device_except_swap_blocks(accelerator.device) # reduce peak memory usage + optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler) + + # 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする + if args.full_fp16: + # During deepseed training, accelerate not handles fp16/bf16|mixed precision directly via scaler. Let deepspeed engine do. + # -> But we think it's ok to patch accelerator even if deepspeed is enabled. + train_util.patch_accelerator_for_fp16_training(accelerator) + + # resumeする + train_util.resume_from_local_or_hf_if_specified(accelerator, args) + + if args.fused_backward_pass: + # use fused optimizer for backward pass: other optimizers will be supported in the future + import library.adafactor_fused + + library.adafactor_fused.patch_adafactor_fused(optimizer) + + for param_group, param_name_group in zip(optimizer.param_groups, param_names): + for parameter, param_name in zip(param_group["params"], param_name_group): + if parameter.requires_grad: + + def create_grad_hook(p_name, p_group): + def grad_hook(tensor: torch.Tensor): + if accelerator.sync_gradients and args.max_grad_norm != 0.0: + accelerator.clip_grad_norm_(tensor, args.max_grad_norm) + optimizer.step_param(tensor, p_group) + tensor.grad = None + + return grad_hook + + parameter.register_post_accumulate_grad_hook(create_grad_hook(param_name, param_group)) + + elif args.blockwise_fused_optimizers: + # prepare for additional optimizers and lr schedulers + for i in range(1, len(optimizers)): + optimizers[i] = accelerator.prepare(optimizers[i]) + lr_schedulers[i] = accelerator.prepare(lr_schedulers[i]) + + # counters are used to determine when to step the optimizer + global optimizer_hooked_count + global num_parameters_per_group + global parameter_optimizer_map + + optimizer_hooked_count = {} + num_parameters_per_group = [0] * len(optimizers) + parameter_optimizer_map = {} + + for opt_idx, optimizer in enumerate(optimizers): + for param_group in optimizer.param_groups: + for parameter in param_group["params"]: + if parameter.requires_grad: + + def grad_hook(parameter: torch.Tensor): + if accelerator.sync_gradients and args.max_grad_norm != 0.0: + accelerator.clip_grad_norm_(parameter, args.max_grad_norm) + + i = parameter_optimizer_map[parameter] + optimizer_hooked_count[i] += 1 + if optimizer_hooked_count[i] == num_parameters_per_group[i]: + optimizers[i].step() + optimizers[i].zero_grad(set_to_none=True) + + parameter.register_post_accumulate_grad_hook(grad_hook) + parameter_optimizer_map[parameter] = opt_idx + num_parameters_per_group[opt_idx] += 1 + + # epoch数を計算する + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0): + args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1 + + # 学習する + # total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + accelerator.print("running training / 学習開始") + accelerator.print(f" num examples / サンプル数: {train_dataset_group.num_train_images}") + accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}") + accelerator.print(f" num epochs / epoch数: {num_train_epochs}") + accelerator.print( + f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}" + ) + # accelerator.print( + # f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}" + # ) + accelerator.print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}") + accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}") + + progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps") + global_step = 0 + + noise_scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=args.discrete_flow_shift) + noise_scheduler_copy = copy.deepcopy(noise_scheduler) + + if accelerator.is_main_process: + init_kwargs = {} + if args.wandb_run_name: + init_kwargs["wandb"] = {"name": args.wandb_run_name} + if args.log_tracker_config is not None: + init_kwargs = toml.load(args.log_tracker_config) + accelerator.init_trackers( + "finetuning" if args.log_tracker_name is None else args.log_tracker_name, + config=train_util.get_sanitized_config_or_none(args), + init_kwargs=init_kwargs, + ) + + if is_swapping_blocks: + accelerator.unwrap_model(flux).prepare_block_swap_before_forward() + + # For --sample_at_first + optimizer_eval_fn() + flux_train_utils.sample_images(accelerator, args, 0, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs) + optimizer_train_fn() + if len(accelerator.trackers) > 0: + # log empty object to commit the sample images to wandb + accelerator.log({}, step=0) + + loss_recorder = train_util.LossRecorder() + epoch = 0 # avoid error when max_train_steps is 0 + for epoch in range(num_train_epochs): + accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}") + current_epoch.value = epoch + 1 + + for m in training_models: + m.train() + + for step, batch in enumerate(train_dataloader): + current_step.value = global_step + + if args.blockwise_fused_optimizers: + optimizer_hooked_count = {i: 0 for i in range(len(optimizers))} # reset counter for each step + + with accelerator.accumulate(*training_models): + if "latents" in batch and batch["latents"] is not None: + latents = batch["latents"].to(accelerator.device, dtype=weight_dtype) + else: + with torch.no_grad(): + # encode images to latents. images are [-1, 1] + latents = ae.encode(batch["images"].to(ae.dtype)).to(accelerator.device, dtype=weight_dtype) + + # NaNが含まれていれば警告を表示し0に置き換える + if torch.any(torch.isnan(latents)): + accelerator.print("NaN found in latents, replacing with zeros") + latents = torch.nan_to_num(latents, 0, out=latents) + + text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None) + if text_encoder_outputs_list is not None: + text_encoder_conds = text_encoder_outputs_list + else: + # not cached or training, so get from text encoders + tokens_and_masks = batch["input_ids_list"] + with torch.no_grad(): + input_ids = [ids.to(accelerator.device) for ids in batch["input_ids_list"]] + text_encoder_conds = text_encoding_strategy.encode_tokens( + flux_tokenize_strategy, [clip_l, t5xxl], input_ids, args.apply_t5_attn_mask + ) + if args.full_fp16: + text_encoder_conds = [c.to(weight_dtype) for c in text_encoder_conds] + + # TODO support some features for noise implemented in get_noise_noisy_latents_and_timesteps + + # Sample noise that we'll add to the latents + noise = torch.randn_like(latents) + bsz = latents.shape[0] + + # get noisy model input and timesteps + noisy_model_input, timesteps, sigmas = flux_train_utils.get_noisy_model_input_and_timesteps( + args, noise_scheduler_copy, latents, noise, accelerator.device, weight_dtype + ) + + # pack latents and get img_ids + packed_noisy_model_input = flux_utils.pack_latents(noisy_model_input) # b, c, h*2, w*2 -> b, h*w, c*4 + packed_latent_height, packed_latent_width = noisy_model_input.shape[2] // 2, noisy_model_input.shape[3] // 2 + img_ids = flux_utils.prepare_img_ids(bsz, packed_latent_height, packed_latent_width).to(device=accelerator.device) + + # get guidance: ensure args.guidance_scale is float + guidance_vec = torch.full((bsz,), float(args.guidance_scale), device=accelerator.device) + + # call model + l_pooled, t5_out, txt_ids, t5_attn_mask = text_encoder_conds + if not args.apply_t5_attn_mask: + t5_attn_mask = None + + with accelerator.autocast(): + # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing) + model_pred = flux( + img=packed_noisy_model_input, + img_ids=img_ids, + txt=t5_out, + txt_ids=txt_ids, + y=l_pooled, + timesteps=timesteps / 1000, + guidance=guidance_vec, + txt_attention_mask=t5_attn_mask, + ) + + # unpack latents + model_pred = flux_utils.unpack_latents(model_pred, packed_latent_height, packed_latent_width) + + # apply model prediction type + model_pred, weighting = flux_train_utils.apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas) + + # flow matching loss: this is different from SD3 + target = noise - latents + + # calculate loss + huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler) + loss = train_util.conditional_loss(model_pred.float(), target.float(), args.loss_type, "none", huber_c) + if weighting is not None: + loss = loss * weighting + if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None): + loss = apply_masked_loss(loss, batch) + loss = loss.mean([1, 2, 3]) + + loss_weights = batch["loss_weights"] # 各sampleごとのweight + loss = loss * loss_weights + loss = loss.mean() + + # backward + accelerator.backward(loss) + + if not (args.fused_backward_pass or args.blockwise_fused_optimizers): + if accelerator.sync_gradients and args.max_grad_norm != 0.0: + params_to_clip = [] + for m in training_models: + params_to_clip.extend(m.parameters()) + accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) + + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad(set_to_none=True) + else: + # optimizer.step() and optimizer.zero_grad() are called in the optimizer hook + lr_scheduler.step() + if args.blockwise_fused_optimizers: + for i in range(1, len(optimizers)): + lr_schedulers[i].step() + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + global_step += 1 + + optimizer_eval_fn() + flux_train_utils.sample_images( + accelerator, args, None, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs + ) + + # 指定ステップごとにモデルを保存 + if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0: + accelerator.wait_for_everyone() + if accelerator.is_main_process: + flux_train_utils.save_flux_model_on_epoch_end_or_stepwise( + args, + False, + accelerator, + save_dtype, + epoch, + num_train_epochs, + global_step, + accelerator.unwrap_model(flux), + ) + optimizer_train_fn() + + current_loss = loss.detach().item() # 平均なのでbatch sizeは関係ないはず + if len(accelerator.trackers) > 0: + logs = {"loss": current_loss} + train_util.append_lr_to_logs(logs, lr_scheduler, args.optimizer_type, including_unet=True) + + accelerator.log(logs, step=global_step) + + loss_recorder.add(epoch=epoch, step=step, loss=current_loss) + avr_loss: float = loss_recorder.moving_average + logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]} + progress_bar.set_postfix(**logs) + + if global_step >= args.max_train_steps: + break + + if len(accelerator.trackers) > 0: + logs = {"loss/epoch": loss_recorder.moving_average} + accelerator.log(logs, step=epoch + 1) + + accelerator.wait_for_everyone() + + optimizer_eval_fn() + if args.save_every_n_epochs is not None: + if accelerator.is_main_process: + flux_train_utils.save_flux_model_on_epoch_end_or_stepwise( + args, + True, + accelerator, + save_dtype, + epoch, + num_train_epochs, + global_step, + accelerator.unwrap_model(flux), + ) + + flux_train_utils.sample_images( + accelerator, args, epoch + 1, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs + ) + optimizer_train_fn() + + is_main_process = accelerator.is_main_process + # if is_main_process: + flux = accelerator.unwrap_model(flux) + + accelerator.end_training() + optimizer_eval_fn() + + if args.save_state or args.save_state_on_train_end: + train_util.save_state_on_train_end(args, accelerator) + + del accelerator # この後メモリを使うのでこれは消す + + if is_main_process: + flux_train_utils.save_flux_model_on_train_end(args, save_dtype, epoch, global_step, flux) + logger.info("model saved.") + + +def setup_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser() + + add_logging_arguments(parser) + train_util.add_sd_models_arguments(parser) # TODO split this + sai_model_spec.add_model_spec_arguments(parser) + train_util.add_dataset_arguments(parser, True, True, True) + train_util.add_training_arguments(parser, False) + train_util.add_masked_loss_arguments(parser) + deepspeed_utils.add_deepspeed_arguments(parser) + train_util.add_sd_saving_arguments(parser) + train_util.add_optimizer_arguments(parser) + config_util.add_config_arguments(parser) + add_custom_train_arguments(parser) # TODO remove this from here + train_util.add_dit_training_arguments(parser) + flux_train_utils.add_flux_train_arguments(parser) + + parser.add_argument( + "--mem_eff_save", + action="store_true", + help="[EXPERIMENTAL] use memory efficient custom model saving method / メモリ効率の良い独自のモデル保存方法を使う", + ) + + parser.add_argument( + "--fused_optimizer_groups", + type=int, + default=None, + help="**this option is not working** will be removed in the future / このオプションは動作しません。将来削除されます", + ) + parser.add_argument( + "--blockwise_fused_optimizers", + action="store_true", + help="enable blockwise optimizers for fused backward pass and optimizer step / fused backward passとoptimizer step のためブロック単位のoptimizerを有効にする", + ) + parser.add_argument( + "--skip_latents_validity_check", + action="store_true", + help="[Deprecated] use 'skip_cache_check' instead / 代わりに 'skip_cache_check' を使用してください", + ) + parser.add_argument( + "--double_blocks_to_swap", + type=int, + default=None, + help="[Deprecated] use 'blocks_to_swap' instead / 代わりに 'blocks_to_swap' を使用してください", + ) + parser.add_argument( + "--single_blocks_to_swap", + type=int, + default=None, + help="[Deprecated] use 'blocks_to_swap' instead / 代わりに 'blocks_to_swap' を使用してください", + ) + parser.add_argument( + "--cpu_offload_checkpointing", + action="store_true", + help="[EXPERIMENTAL] enable offloading of tensors to CPU during checkpointing / チェックポイント時にテンソルをCPUにオフロードする", + ) + return parser + + +if __name__ == "__main__": + parser = setup_parser() + + args = parser.parse_args() + train_util.verify_command_line_training_args(args) + args = train_util.read_config_from_file(args, parser) + + train(args) diff --git a/flux_train_control_net.py b/flux_train_control_net.py new file mode 100644 index 000000000..019914058 --- /dev/null +++ b/flux_train_control_net.py @@ -0,0 +1,885 @@ +# training with captions + +# Swap blocks between CPU and GPU: +# This implementation is inspired by and based on the work of 2kpr. +# Many thanks to 2kpr for the original concept and implementation of memory-efficient offloading. +# The original idea has been adapted and extended to fit the current project's needs. + +# Key features: +# - CPU offloading during forward and backward passes +# - Use of fused optimizer and grad_hook for efficient gradient processing +# - Per-block fused optimizer instances + +import argparse +import copy +import math +import os +import time +from concurrent.futures import ThreadPoolExecutor +from multiprocessing import Value +from typing import List, Optional, Tuple, Union + +import toml +import torch +import torch.nn as nn +from tqdm import tqdm + +from library import utils +from library.device_utils import clean_memory_on_device, init_ipex + +init_ipex() + +from accelerate.utils import set_seed + +import library.train_util as train_util +import library.sai_model_spec as sai_model_spec +from library import ( + deepspeed_utils, + flux_train_utils, + flux_utils, + strategy_base, + strategy_flux, +) +from library.sd3_train_utils import FlowMatchEulerDiscreteScheduler +from library.utils import add_logging_arguments, setup_logging + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + +import library.config_util as config_util + +# import library.sdxl_train_util as sdxl_train_util +from library.config_util import ( + BlueprintGenerator, + ConfigSanitizer, +) +from library.custom_train_functions import add_custom_train_arguments, apply_masked_loss + + +def train(args): + train_util.verify_training_args(args) + train_util.prepare_dataset_args(args, True) + # sdxl_train_util.verify_sdxl_training_args(args) + deepspeed_utils.prepare_deepspeed_args(args) + setup_logging(args, reset=True) + + # temporary: backward compatibility for deprecated options. remove in the future + if not args.skip_cache_check: + args.skip_cache_check = args.skip_latents_validity_check + + if args.model_type != "flux": + raise ValueError( + f"FLUX.1 ControlNet training requires model_type='flux'. / FLUX.1 ControlNetの学習にはmodel_type='flux'を指定してください。" + ) + + # assert ( + # not args.weighted_captions + # ), "weighted_captions is not supported currently / weighted_captionsは現在サポートされていません" + if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs: + logger.warning( + "cache_text_encoder_outputs_to_disk is enabled, so cache_text_encoder_outputs is also enabled / cache_text_encoder_outputs_to_diskが有効になっているため、cache_text_encoder_outputsも有効になります" + ) + args.cache_text_encoder_outputs = True + + if args.cpu_offload_checkpointing and not args.gradient_checkpointing: + logger.warning( + "cpu_offload_checkpointing is enabled, so gradient_checkpointing is also enabled / cpu_offload_checkpointingが有効になっているため、gradient_checkpointingも有効になります" + ) + args.gradient_checkpointing = True + + assert ( + args.blocks_to_swap is None or args.blocks_to_swap == 0 + ) or not args.cpu_offload_checkpointing, ( + "blocks_to_swap is not supported with cpu_offload_checkpointing / blocks_to_swapはcpu_offload_checkpointingと併用できません" + ) + + cache_latents = args.cache_latents + + if args.seed is not None: + set_seed(args.seed) # 乱数系列を初期化する + + # prepare caching strategy: this must be set before preparing dataset. because dataset may use this strategy for initialization. + if args.cache_latents: + latents_caching_strategy = strategy_flux.FluxLatentsCachingStrategy( + args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check + ) + strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy) + + # データセットを準備する + if args.dataset_class is None: + blueprint_generator = BlueprintGenerator(ConfigSanitizer(False, False, True, True)) + if args.dataset_config is not None: + logger.info(f"Load dataset config from {args.dataset_config}") + user_config = config_util.load_user_config(args.dataset_config) + ignored = ["train_data_dir", "conditioning_data_dir"] + if any(getattr(args, attr) is not None for attr in ignored): + logger.warning( + "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( + ", ".join(ignored) + ) + ) + else: + user_config = { + "datasets": [ + { + "subsets": config_util.generate_controlnet_subsets_config_by_subdirs( + args.train_data_dir, args.conditioning_data_dir, args.caption_extension + ) + } + ] + } + + blueprint = blueprint_generator.generate(user_config, args) + train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) + else: + train_dataset_group = train_util.load_arbitrary_dataset(args) + val_dataset_group = None + + current_epoch = Value("i", 0) + current_step = Value("i", 0) + ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None + collator = train_util.collator_class(current_epoch, current_step, ds_for_collator) + + train_dataset_group.verify_bucket_reso_steps(16) # TODO これでいいか確認 + + _, is_schnell, _, _ = flux_utils.analyze_checkpoint_state(args.pretrained_model_name_or_path) + if args.debug_dataset: + if args.cache_text_encoder_outputs: + strategy_base.TextEncoderOutputsCachingStrategy.set_strategy( + strategy_flux.FluxTextEncoderOutputsCachingStrategy( + args.cache_text_encoder_outputs_to_disk, args.text_encoder_batch_size, args.skip_cache_check, False + ) + ) + t5xxl_max_token_length = ( + args.t5xxl_max_token_length if args.t5xxl_max_token_length is not None else (256 if is_schnell else 512) + ) + strategy_base.TokenizeStrategy.set_strategy(strategy_flux.FluxTokenizeStrategy(t5xxl_max_token_length)) + + train_dataset_group.set_current_strategies() + train_util.debug_dataset(train_dataset_group, True) + return + if len(train_dataset_group) == 0: + logger.error( + "No data found. Please verify the metadata file and train_data_dir option. / 画像がありません。メタデータおよびtrain_data_dirオプションを確認してください。" + ) + return + + if cache_latents: + assert ( + train_dataset_group.is_latent_cacheable() + ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません" + + if args.cache_text_encoder_outputs: + assert ( + train_dataset_group.is_text_encoder_output_cacheable() + ), "when caching text encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / text encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません" + + # acceleratorを準備する + logger.info("prepare accelerator") + accelerator = train_util.prepare_accelerator(args) + + # mixed precisionに対応した型を用意しておき適宜castする + weight_dtype, save_dtype = train_util.prepare_dtype(args) + + # モデルを読み込む + + # load VAE for caching latents + ae = None + if cache_latents: + ae = flux_utils.load_ae(args.ae, weight_dtype, "cpu", args.disable_mmap_load_safetensors) + ae.to(accelerator.device, dtype=weight_dtype) + ae.requires_grad_(False) + ae.eval() + + train_dataset_group.new_cache_latents(ae, accelerator) + + ae.to("cpu") # if no sampling, vae can be deleted + clean_memory_on_device(accelerator.device) + + accelerator.wait_for_everyone() + + # prepare tokenize strategy + if args.t5xxl_max_token_length is None: + if is_schnell: + t5xxl_max_token_length = 256 + else: + t5xxl_max_token_length = 512 + else: + t5xxl_max_token_length = args.t5xxl_max_token_length + + flux_tokenize_strategy = strategy_flux.FluxTokenizeStrategy(t5xxl_max_token_length) + strategy_base.TokenizeStrategy.set_strategy(flux_tokenize_strategy) + + # load clip_l, t5xxl for caching text encoder outputs + clip_l = flux_utils.load_clip_l(args.clip_l, weight_dtype, "cpu", args.disable_mmap_load_safetensors) + t5xxl = flux_utils.load_t5xxl(args.t5xxl, weight_dtype, "cpu", args.disable_mmap_load_safetensors) + clip_l.eval() + t5xxl.eval() + clip_l.requires_grad_(False) + t5xxl.requires_grad_(False) + + text_encoding_strategy = strategy_flux.FluxTextEncodingStrategy(args.apply_t5_attn_mask) + strategy_base.TextEncodingStrategy.set_strategy(text_encoding_strategy) + + # cache text encoder outputs + sample_prompts_te_outputs = None + if args.cache_text_encoder_outputs: + # Text Encodes are eval and no grad here + clip_l.to(accelerator.device) + t5xxl.to(accelerator.device) + + text_encoder_caching_strategy = strategy_flux.FluxTextEncoderOutputsCachingStrategy( + args.cache_text_encoder_outputs_to_disk, args.text_encoder_batch_size, False, False, args.apply_t5_attn_mask + ) + strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(text_encoder_caching_strategy) + + with accelerator.autocast(): + train_dataset_group.new_cache_text_encoder_outputs([clip_l, t5xxl], accelerator) + + # cache sample prompt's embeddings to free text encoder's memory + if args.sample_prompts is not None: + logger.info(f"cache Text Encoder outputs for sample prompt: {args.sample_prompts}") + + text_encoding_strategy: strategy_flux.FluxTextEncodingStrategy = strategy_base.TextEncodingStrategy.get_strategy() + + prompts = train_util.load_prompts(args.sample_prompts) + sample_prompts_te_outputs = {} # key: prompt, value: text encoder outputs + with accelerator.autocast(), torch.no_grad(): + for prompt_dict in prompts: + for p in [prompt_dict.get("prompt", ""), prompt_dict.get("negative_prompt", "")]: + if p not in sample_prompts_te_outputs: + logger.info(f"cache Text Encoder outputs for prompt: {p}") + tokens_and_masks = flux_tokenize_strategy.tokenize(p) + sample_prompts_te_outputs[p] = text_encoding_strategy.encode_tokens( + flux_tokenize_strategy, [clip_l, t5xxl], tokens_and_masks, args.apply_t5_attn_mask + ) + + accelerator.wait_for_everyone() + + # now we can delete Text Encoders to free memory + clip_l = None + t5xxl = None + clean_memory_on_device(accelerator.device) + + # load FLUX + is_schnell, flux = flux_utils.load_flow_model( + args.pretrained_model_name_or_path, weight_dtype, "cpu", args.disable_mmap_load_safetensors, model_type="flux" + ) + flux.requires_grad_(False) + + # load controlnet + controlnet_dtype = torch.float32 if args.deepspeed else weight_dtype + controlnet = flux_utils.load_controlnet( + args.controlnet_model_name_or_path, is_schnell, controlnet_dtype, accelerator.device, args.disable_mmap_load_safetensors + ) + controlnet.train() + + if args.gradient_checkpointing: + if not args.deepspeed: + flux.enable_gradient_checkpointing(cpu_offload=args.cpu_offload_checkpointing) + controlnet.enable_gradient_checkpointing(cpu_offload=args.cpu_offload_checkpointing) + + # block swap + + # backward compatibility + if args.blocks_to_swap is None: + blocks_to_swap = args.double_blocks_to_swap or 0 + if args.single_blocks_to_swap is not None: + blocks_to_swap += args.single_blocks_to_swap // 2 + if blocks_to_swap > 0: + logger.warning( + "double_blocks_to_swap and single_blocks_to_swap are deprecated. Use blocks_to_swap instead." + " / double_blocks_to_swapとsingle_blocks_to_swapは非推奨です。blocks_to_swapを使ってください。" + ) + logger.info( + f"double_blocks_to_swap={args.double_blocks_to_swap} and single_blocks_to_swap={args.single_blocks_to_swap} are converted to blocks_to_swap={blocks_to_swap}." + ) + args.blocks_to_swap = blocks_to_swap + del blocks_to_swap + + is_swapping_blocks = args.blocks_to_swap is not None and args.blocks_to_swap > 0 + if is_swapping_blocks: + # Swap blocks between CPU and GPU to reduce memory usage, in forward and backward passes. + # This idea is based on 2kpr's great work. Thank you! + logger.info(f"enable block swap: blocks_to_swap={args.blocks_to_swap}") + flux.enable_block_swap(args.blocks_to_swap, accelerator.device) + flux.move_to_device_except_swap_blocks(accelerator.device) # reduce peak memory usage + # ControlNet only has two blocks, so we can keep it on GPU + # controlnet.enable_block_swap(args.blocks_to_swap, accelerator.device) + else: + flux.to(accelerator.device) + + if not cache_latents: + # load VAE here if not cached + ae = flux_utils.load_ae(args.ae, weight_dtype, "cpu") + ae.requires_grad_(False) + ae.eval() + ae.to(accelerator.device, dtype=weight_dtype) + + training_models = [] + params_to_optimize = [] + training_models.append(controlnet) + name_and_params = list(controlnet.named_parameters()) + # single param group for now + params_to_optimize.append({"params": [p for _, p in name_and_params], "lr": args.learning_rate}) + param_names = [[n for n, _ in name_and_params]] + + # calculate number of trainable parameters + n_params = 0 + for group in params_to_optimize: + for p in group["params"]: + n_params += p.numel() + + accelerator.print(f"number of trainable parameters: {n_params}") + + # 学習に必要なクラスを準備する + accelerator.print("prepare optimizer, data loader etc.") + + if args.blockwise_fused_optimizers: + # fused backward pass: https://pytorch.org/tutorials/intermediate/optimizer_step_in_backward_tutorial.html + # Instead of creating an optimizer for all parameters as in the tutorial, we create an optimizer for each block of parameters. + # This balances memory usage and management complexity. + + # split params into groups. currently different learning rates are not supported + grouped_params = [] + param_group = {} + for group in params_to_optimize: + named_parameters = list(controlnet.named_parameters()) + assert len(named_parameters) == len(group["params"]), "number of parameters does not match" + for p, np in zip(group["params"], named_parameters): + # determine target layer and block index for each parameter + block_type = "other" # double, single or other + if np[0].startswith("double_blocks"): + block_index = int(np[0].split(".")[1]) + block_type = "double" + elif np[0].startswith("single_blocks"): + block_index = int(np[0].split(".")[1]) + block_type = "single" + else: + block_index = -1 + + param_group_key = (block_type, block_index) + if param_group_key not in param_group: + param_group[param_group_key] = [] + param_group[param_group_key].append(p) + + block_types_and_indices = [] + for param_group_key, param_group in param_group.items(): + block_types_and_indices.append(param_group_key) + grouped_params.append({"params": param_group, "lr": args.learning_rate}) + + num_params = 0 + for p in param_group: + num_params += p.numel() + accelerator.print(f"block {param_group_key}: {num_params} parameters") + + # prepare optimizers for each group + optimizers = [] + for group in grouped_params: + _, _, optimizer = train_util.get_optimizer(args, trainable_params=[group]) + optimizers.append(optimizer) + optimizer = optimizers[0] # avoid error in the following code + + logger.info(f"using {len(optimizers)} optimizers for blockwise fused optimizers") + + if train_util.is_schedulefree_optimizer(optimizers[0], args): + raise ValueError("Schedule-free optimizer is not supported with blockwise fused optimizers") + optimizer_train_fn = lambda: None # dummy function + optimizer_eval_fn = lambda: None # dummy function + else: + _, _, optimizer = train_util.get_optimizer(args, trainable_params=params_to_optimize) + optimizer_train_fn, optimizer_eval_fn = train_util.get_optimizer_train_eval_fn(optimizer, args) + + # prepare dataloader + # strategies are set here because they cannot be referenced in another process. Copy them with the dataset + # some strategies can be None + train_dataset_group.set_current_strategies() + + # DataLoaderのプロセス数:0 は persistent_workers が使えないので注意 + n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers + train_dataloader = torch.utils.data.DataLoader( + train_dataset_group, + batch_size=1, + shuffle=True, + collate_fn=collator, + num_workers=n_workers, + persistent_workers=args.persistent_data_loader_workers, + ) + + # 学習ステップ数を計算する + if args.max_train_epochs is not None: + args.max_train_steps = args.max_train_epochs * math.ceil( + len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps + ) + accelerator.print( + f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}" + ) + + # データセット側にも学習ステップを送信 + train_dataset_group.set_max_train_steps(args.max_train_steps) + + # lr schedulerを用意する + if args.blockwise_fused_optimizers: + # prepare lr schedulers for each optimizer + lr_schedulers = [train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) for optimizer in optimizers] + lr_scheduler = lr_schedulers[0] # avoid error in the following code + else: + lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) + + # 実験的機能:勾配も含めたfp16/bf16学習を行う モデル全体をfp16/bf16にする + if args.full_fp16: + assert ( + args.mixed_precision == "fp16" + ), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。" + accelerator.print("enable full fp16 training.") + flux.to(weight_dtype) + controlnet.to(weight_dtype) + if clip_l is not None: + clip_l.to(weight_dtype) + t5xxl.to(weight_dtype) # TODO check works with fp16 or not + elif args.full_bf16: + assert ( + args.mixed_precision == "bf16" + ), "full_bf16 requires mixed precision='bf16' / full_bf16を使う場合はmixed_precision='bf16'を指定してください。" + accelerator.print("enable full bf16 training.") + flux.to(weight_dtype) + controlnet.to(weight_dtype) + if clip_l is not None: + clip_l.to(weight_dtype) + t5xxl.to(weight_dtype) + + # if we don't cache text encoder outputs, move them to device + if not args.cache_text_encoder_outputs: + clip_l.to(accelerator.device) + t5xxl.to(accelerator.device) + + clean_memory_on_device(accelerator.device) + + if args.deepspeed: + ds_model = deepspeed_utils.prepare_deepspeed_model(args, mmdit=controlnet) + # most of ZeRO stage uses optimizer partitioning, so we have to prepare optimizer and ds_model at the same time. # pull/1139#issuecomment-1986790007 + ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + ds_model, optimizer, train_dataloader, lr_scheduler + ) + training_models = [ds_model] + + else: + # accelerator does some magic + # if we doesn't swap blocks, we can move the model to device + controlnet = accelerator.prepare(controlnet) # , device_placement=[not is_swapping_blocks]) + optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler) + + # 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする + if args.full_fp16: + # During deepseed training, accelerate not handles fp16/bf16|mixed precision directly via scaler. Let deepspeed engine do. + # -> But we think it's ok to patch accelerator even if deepspeed is enabled. + train_util.patch_accelerator_for_fp16_training(accelerator) + + # resumeする + train_util.resume_from_local_or_hf_if_specified(accelerator, args) + + if args.fused_backward_pass: + # use fused optimizer for backward pass: other optimizers will be supported in the future + import library.adafactor_fused + + library.adafactor_fused.patch_adafactor_fused(optimizer) + + for param_group, param_name_group in zip(optimizer.param_groups, param_names): + for parameter, param_name in zip(param_group["params"], param_name_group): + if parameter.requires_grad: + + def create_grad_hook(p_name, p_group): + def grad_hook(tensor: torch.Tensor): + if accelerator.sync_gradients and args.max_grad_norm != 0.0: + accelerator.clip_grad_norm_(tensor, args.max_grad_norm) + optimizer.step_param(tensor, p_group) + tensor.grad = None + + return grad_hook + + parameter.register_post_accumulate_grad_hook(create_grad_hook(param_name, param_group)) + + elif args.blockwise_fused_optimizers: + # prepare for additional optimizers and lr schedulers + for i in range(1, len(optimizers)): + optimizers[i] = accelerator.prepare(optimizers[i]) + lr_schedulers[i] = accelerator.prepare(lr_schedulers[i]) + + # counters are used to determine when to step the optimizer + global optimizer_hooked_count + global num_parameters_per_group + global parameter_optimizer_map + + optimizer_hooked_count = {} + num_parameters_per_group = [0] * len(optimizers) + parameter_optimizer_map = {} + + for opt_idx, optimizer in enumerate(optimizers): + for param_group in optimizer.param_groups: + for parameter in param_group["params"]: + if parameter.requires_grad: + + def grad_hook(parameter: torch.Tensor): + if accelerator.sync_gradients and args.max_grad_norm != 0.0: + accelerator.clip_grad_norm_(parameter, args.max_grad_norm) + + i = parameter_optimizer_map[parameter] + optimizer_hooked_count[i] += 1 + if optimizer_hooked_count[i] == num_parameters_per_group[i]: + optimizers[i].step() + optimizers[i].zero_grad(set_to_none=True) + + parameter.register_post_accumulate_grad_hook(grad_hook) + parameter_optimizer_map[parameter] = opt_idx + num_parameters_per_group[opt_idx] += 1 + + # epoch数を計算する + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0): + args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1 + + # 学習する + # total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + accelerator.print("running training / 学習開始") + accelerator.print(f" num examples / サンプル数: {train_dataset_group.num_train_images}") + accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}") + accelerator.print(f" num epochs / epoch数: {num_train_epochs}") + accelerator.print( + f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}" + ) + # accelerator.print( + # f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}" + # ) + accelerator.print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}") + accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}") + + progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps") + global_step = 0 + + noise_scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=args.discrete_flow_shift) + noise_scheduler_copy = copy.deepcopy(noise_scheduler) + + if accelerator.is_main_process: + init_kwargs = {} + if args.wandb_run_name: + init_kwargs["wandb"] = {"name": args.wandb_run_name} + if args.log_tracker_config is not None: + init_kwargs = toml.load(args.log_tracker_config) + accelerator.init_trackers( + "finetuning" if args.log_tracker_name is None else args.log_tracker_name, + config=train_util.get_sanitized_config_or_none(args), + init_kwargs=init_kwargs, + ) + + if is_swapping_blocks: + flux.prepare_block_swap_before_forward() + + # For --sample_at_first + optimizer_eval_fn() + flux_train_utils.sample_images( + accelerator, args, 0, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs, controlnet=controlnet + ) + optimizer_train_fn() + if len(accelerator.trackers) > 0: + # log empty object to commit the sample images to wandb + accelerator.log({}, step=0) + + loss_recorder = train_util.LossRecorder() + epoch = 0 # avoid error when max_train_steps is 0 + for epoch in range(num_train_epochs): + accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}") + current_epoch.value = epoch + 1 + + for m in training_models: + m.train() + + for step, batch in enumerate(train_dataloader): + current_step.value = global_step + + if args.blockwise_fused_optimizers: + optimizer_hooked_count = {i: 0 for i in range(len(optimizers))} # reset counter for each step + + with accelerator.accumulate(*training_models): + if "latents" in batch and batch["latents"] is not None: + latents = batch["latents"].to(accelerator.device, dtype=weight_dtype) + else: + with torch.no_grad(): + # encode images to latents. images are [-1, 1] + latents = ae.encode(batch["images"].to(ae.dtype)).to(accelerator.device, dtype=weight_dtype) + + # NaNが含まれていれば警告を表示し0に置き換える + if torch.any(torch.isnan(latents)): + accelerator.print("NaN found in latents, replacing with zeros") + latents = torch.nan_to_num(latents, 0, out=latents) + + text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None) + if text_encoder_outputs_list is not None: + text_encoder_conds = text_encoder_outputs_list + else: + # not cached or training, so get from text encoders + tokens_and_masks = batch["input_ids_list"] + with torch.no_grad(): + input_ids = [ids.to(accelerator.device) for ids in batch["input_ids_list"]] + text_encoder_conds = text_encoding_strategy.encode_tokens( + flux_tokenize_strategy, [clip_l, t5xxl], input_ids, args.apply_t5_attn_mask + ) + text_encoder_conds = [c.to(weight_dtype) for c in text_encoder_conds] + + # TODO support some features for noise implemented in get_noise_noisy_latents_and_timesteps + + # Sample noise that we'll add to the latents + noise = torch.randn_like(latents) + bsz = latents.shape[0] + + # get noisy model input and timesteps + noisy_model_input, timesteps, sigmas = flux_train_utils.get_noisy_model_input_and_timesteps( + args, noise_scheduler_copy, latents, noise, accelerator.device, weight_dtype + ) + + # pack latents and get img_ids + packed_noisy_model_input = flux_utils.pack_latents(noisy_model_input) # b, c, h*2, w*2 -> b, h*w, c*4 + packed_latent_height, packed_latent_width = noisy_model_input.shape[2] // 2, noisy_model_input.shape[3] // 2 + img_ids = ( + flux_utils.prepare_img_ids(bsz, packed_latent_height, packed_latent_width) + .to(device=accelerator.device) + .to(weight_dtype) + ) + + # get guidance: ensure args.guidance_scale is float + guidance_vec = torch.full((bsz,), float(args.guidance_scale), device=accelerator.device, dtype=weight_dtype) + + # call model + l_pooled, t5_out, txt_ids, t5_attn_mask = text_encoder_conds + if not args.apply_t5_attn_mask: + t5_attn_mask = None + + with accelerator.autocast(): + block_samples, block_single_samples = controlnet( + img=packed_noisy_model_input, + img_ids=img_ids, + controlnet_cond=batch["conditioning_images"].to(accelerator.device).to(weight_dtype), + txt=t5_out, + txt_ids=txt_ids, + y=l_pooled, + timesteps=timesteps / 1000, + guidance=guidance_vec, + txt_attention_mask=t5_attn_mask, + ) + # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing) + model_pred = flux( + img=packed_noisy_model_input, + img_ids=img_ids, + txt=t5_out, + txt_ids=txt_ids, + y=l_pooled, + block_controlnet_hidden_states=block_samples, + block_controlnet_single_hidden_states=block_single_samples, + timesteps=timesteps / 1000, + guidance=guidance_vec, + txt_attention_mask=t5_attn_mask, + ) + + # unpack latents + model_pred = flux_utils.unpack_latents(model_pred, packed_latent_height, packed_latent_width) + + # apply model prediction type + model_pred, weighting = flux_train_utils.apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas) + + # flow matching loss: this is different from SD3 + target = noise - latents + + # calculate loss + loss = train_util.conditional_loss( + model_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=None + ) + if weighting is not None: + loss = loss * weighting + if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None): + loss = apply_masked_loss(loss, batch) + loss = loss.mean([1, 2, 3]) + + loss_weights = batch["loss_weights"] # 各sampleごとのweight + loss = loss * loss_weights + loss = loss.mean() + + # backward + accelerator.backward(loss) + + if not (args.fused_backward_pass or args.blockwise_fused_optimizers): + if accelerator.sync_gradients and args.max_grad_norm != 0.0: + params_to_clip = [] + for m in training_models: + params_to_clip.extend(m.parameters()) + accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) + + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad(set_to_none=True) + else: + # optimizer.step() and optimizer.zero_grad() are called in the optimizer hook + lr_scheduler.step() + if args.blockwise_fused_optimizers: + for i in range(1, len(optimizers)): + lr_schedulers[i].step() + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + global_step += 1 + + optimizer_eval_fn() + flux_train_utils.sample_images( + accelerator, + args, + None, + global_step, + flux, + ae, + [clip_l, t5xxl], + sample_prompts_te_outputs, + controlnet=controlnet, + ) + + # 指定ステップごとにモデルを保存 + if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0: + accelerator.wait_for_everyone() + if accelerator.is_main_process: + flux_train_utils.save_flux_model_on_epoch_end_or_stepwise( + args, + False, + accelerator, + save_dtype, + epoch, + num_train_epochs, + global_step, + accelerator.unwrap_model(controlnet), + ) + optimizer_train_fn() + + current_loss = loss.detach().item() # 平均なのでbatch sizeは関係ないはず + if len(accelerator.trackers) > 0: + logs = {"loss": current_loss} + train_util.append_lr_to_logs(logs, lr_scheduler, args.optimizer_type, including_unet=True) + + accelerator.log(logs, step=global_step) + + loss_recorder.add(epoch=epoch, step=step, loss=current_loss) + avr_loss: float = loss_recorder.moving_average + logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]} + progress_bar.set_postfix(**logs) + + if global_step >= args.max_train_steps: + break + + if len(accelerator.trackers) > 0: + logs = {"loss/epoch": loss_recorder.moving_average} + accelerator.log(logs, step=epoch + 1) + + accelerator.wait_for_everyone() + + optimizer_eval_fn() + if args.save_every_n_epochs is not None: + if accelerator.is_main_process: + flux_train_utils.save_flux_model_on_epoch_end_or_stepwise( + args, + True, + accelerator, + save_dtype, + epoch, + num_train_epochs, + global_step, + accelerator.unwrap_model(controlnet), + ) + + flux_train_utils.sample_images( + accelerator, args, epoch + 1, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs, controlnet=controlnet + ) + optimizer_train_fn() + + is_main_process = accelerator.is_main_process + # if is_main_process: + controlnet = accelerator.unwrap_model(controlnet) + + accelerator.end_training() + optimizer_eval_fn() + + if args.save_state or args.save_state_on_train_end: + train_util.save_state_on_train_end(args, accelerator) + + del accelerator # この後メモリを使うのでこれは消す + + if is_main_process: + flux_train_utils.save_flux_model_on_train_end(args, save_dtype, epoch, global_step, controlnet) + logger.info("model saved.") + + +def setup_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser() + + add_logging_arguments(parser) + train_util.add_sd_models_arguments(parser) # TODO split this + sai_model_spec.add_model_spec_arguments(parser) + train_util.add_dataset_arguments(parser, False, True, True) + train_util.add_training_arguments(parser, False) + train_util.add_masked_loss_arguments(parser) + deepspeed_utils.add_deepspeed_arguments(parser) + train_util.add_sd_saving_arguments(parser) + train_util.add_optimizer_arguments(parser) + config_util.add_config_arguments(parser) + add_custom_train_arguments(parser) # TODO remove this from here + train_util.add_dit_training_arguments(parser) + flux_train_utils.add_flux_train_arguments(parser) + + parser.add_argument( + "--mem_eff_save", + action="store_true", + help="[EXPERIMENTAL] use memory efficient custom model saving method / メモリ効率の良い独自のモデル保存方法を使う", + ) + + parser.add_argument( + "--fused_optimizer_groups", + type=int, + default=None, + help="**this option is not working** will be removed in the future / このオプションは動作しません。将来削除されます", + ) + parser.add_argument( + "--blockwise_fused_optimizers", + action="store_true", + help="enable blockwise optimizers for fused backward pass and optimizer step / fused backward passとoptimizer step のためブロック単位のoptimizerを有効にする", + ) + parser.add_argument( + "--skip_latents_validity_check", + action="store_true", + help="[Deprecated] use 'skip_cache_check' instead / 代わりに 'skip_cache_check' を使用してください", + ) + parser.add_argument( + "--double_blocks_to_swap", + type=int, + default=None, + help="[Deprecated] use 'blocks_to_swap' instead / 代わりに 'blocks_to_swap' を使用してください", + ) + parser.add_argument( + "--single_blocks_to_swap", + type=int, + default=None, + help="[Deprecated] use 'blocks_to_swap' instead / 代わりに 'blocks_to_swap' を使用してください", + ) + parser.add_argument( + "--cpu_offload_checkpointing", + action="store_true", + help="[EXPERIMENTAL] enable offloading of tensors to CPU during checkpointing / チェックポイント時にテンソルをCPUにオフロードする", + ) + return parser + + +if __name__ == "__main__": + parser = setup_parser() + + args = parser.parse_args() + train_util.verify_command_line_training_args(args) + args = train_util.read_config_from_file(args, parser) + + train(args) diff --git a/flux_train_network.py b/flux_train_network.py new file mode 100644 index 000000000..cfc617088 --- /dev/null +++ b/flux_train_network.py @@ -0,0 +1,547 @@ +import argparse +import copy +import math +import random +from typing import Any, Optional, Union + +import torch +from accelerate import Accelerator + +from library.device_utils import clean_memory_on_device, init_ipex + +init_ipex() + +import train_network +from library import ( + flux_models, + flux_train_utils, + flux_utils, + sd3_train_utils, + strategy_base, + strategy_flux, + train_util, +) +from library.utils import setup_logging + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + + +class FluxNetworkTrainer(train_network.NetworkTrainer): + def __init__(self): + super().__init__() + self.sample_prompts_te_outputs = None + self.is_schnell: Optional[bool] = None + self.is_swapping_blocks: bool = False + self.model_type: Optional[str] = None + + def assert_extra_args( + self, + args, + train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset], + val_dataset_group: Optional[train_util.DatasetGroup], + ): + super().assert_extra_args(args, train_dataset_group, val_dataset_group) + # sdxl_train_util.verify_sdxl_training_args(args) + + self.model_type = args.model_type # "flux" or "chroma" + if self.model_type != "chroma": + self.use_clip_l = True + else: + self.use_clip_l = False # Chroma does not use CLIP-L + assert args.apply_t5_attn_mask, "apply_t5_attn_mask must be True for Chroma / Chromaではapply_t5_attn_maskを指定する必要があります" + + if args.fp8_base_unet: + args.fp8_base = True # if fp8_base_unet is enabled, fp8_base is also enabled for FLUX.1 + + if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs: + logger.warning( + "cache_text_encoder_outputs_to_disk is enabled, so cache_text_encoder_outputs is also enabled / cache_text_encoder_outputs_to_diskが有効になっているため、cache_text_encoder_outputsも有効になります" + ) + args.cache_text_encoder_outputs = True + + if args.cache_text_encoder_outputs: + assert ( + train_dataset_group.is_text_encoder_output_cacheable() + ), "when caching Text Encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / Text Encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません" + + # prepare CLIP-L/T5XXL training flags + self.train_clip_l = not args.network_train_unet_only and self.use_clip_l + self.train_t5xxl = False # default is False even if args.network_train_unet_only is False + + if args.max_token_length is not None: + logger.warning("max_token_length is not used in Flux training / max_token_lengthはFluxのトレーニングでは使用されません") + + assert ( + args.blocks_to_swap is None or args.blocks_to_swap == 0 + ) or not args.cpu_offload_checkpointing, "blocks_to_swap is not supported with cpu_offload_checkpointing / blocks_to_swapはcpu_offload_checkpointingと併用できません" + + # deprecated split_mode option + if args.split_mode: + if args.blocks_to_swap is not None: + logger.warning( + "split_mode is deprecated. Because `--blocks_to_swap` is set, `--split_mode` is ignored." + " / split_modeは非推奨です。`--blocks_to_swap`が設定されているため、`--split_mode`は無視されます。" + ) + else: + logger.warning( + "split_mode is deprecated. Please use `--blocks_to_swap` instead. `--blocks_to_swap 18` is automatically set." + " / split_modeは非推奨です。代わりに`--blocks_to_swap`を使用してください。`--blocks_to_swap 18`が自動的に設定されました。" + ) + args.blocks_to_swap = 18 # 18 is safe for most cases + + train_dataset_group.verify_bucket_reso_steps(32) # TODO check this + if val_dataset_group is not None: + val_dataset_group.verify_bucket_reso_steps(32) # TODO check this + + def load_target_model(self, args, weight_dtype, accelerator): + # currently offload to cpu for some models + + # if the file is fp8 and we are using fp8_base, we can load it as is (fp8) + loading_dtype = None if args.fp8_base else weight_dtype + + # if we load to cpu, flux.to(fp8) takes a long time, so we should load to gpu in future + _, model = flux_utils.load_flow_model( + args.pretrained_model_name_or_path, + loading_dtype, + "cpu", + disable_mmap=args.disable_mmap_load_safetensors, + model_type=self.model_type, + ) + if args.fp8_base: + # check dtype of model + if model.dtype == torch.float8_e4m3fnuz or model.dtype == torch.float8_e5m2 or model.dtype == torch.float8_e5m2fnuz: + raise ValueError(f"Unsupported fp8 model dtype: {model.dtype}") + elif model.dtype == torch.float8_e4m3fn: + logger.info("Loaded fp8 FLUX model") + else: + logger.info( + "Cast FLUX model to fp8. This may take a while. You can reduce the time by using fp8 checkpoint." + " / FLUXモデルをfp8に変換しています。これには時間がかかる場合があります。fp8チェックポイントを使用することで時間を短縮できます。" + ) + model.to(torch.float8_e4m3fn) + + # if args.split_mode: + # model = self.prepare_split_model(model, weight_dtype, accelerator) + + self.is_swapping_blocks = args.blocks_to_swap is not None and args.blocks_to_swap > 0 + if self.is_swapping_blocks: + # Swap blocks between CPU and GPU to reduce memory usage, in forward and backward passes. + logger.info(f"enable block swap: blocks_to_swap={args.blocks_to_swap}") + model.enable_block_swap(args.blocks_to_swap, accelerator.device) + + if self.use_clip_l: + clip_l = flux_utils.load_clip_l(args.clip_l, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors) + else: + clip_l = flux_utils.dummy_clip_l() # dummy CLIP-L for Chroma, which does not use CLIP-L + clip_l.eval() + + # if the file is fp8 and we are using fp8_base (not unet), we can load it as is (fp8) + if args.fp8_base and not args.fp8_base_unet: + loading_dtype = None # as is + else: + loading_dtype = weight_dtype + + # loading t5xxl to cpu takes a long time, so we should load to gpu in future + t5xxl = flux_utils.load_t5xxl(args.t5xxl, loading_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors) + t5xxl.eval() + if args.fp8_base and not args.fp8_base_unet: + # check dtype of model + if t5xxl.dtype == torch.float8_e4m3fnuz or t5xxl.dtype == torch.float8_e5m2 or t5xxl.dtype == torch.float8_e5m2fnuz: + raise ValueError(f"Unsupported fp8 model dtype: {t5xxl.dtype}") + elif t5xxl.dtype == torch.float8_e4m3fn: + logger.info("Loaded fp8 T5XXL model") + + ae = flux_utils.load_ae(args.ae, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors) + + model_version = flux_utils.MODEL_VERSION_FLUX_V1 if self.model_type != "chroma" else flux_utils.MODEL_VERSION_CHROMA + return model_version, [clip_l, t5xxl], ae, model + + def get_tokenize_strategy(self, args): + # This method is called before `assert_extra_args`, so we cannot use `self.is_schnell` here. + # Instead, we analyze the checkpoint state to determine if it is schnell. + if args.model_type != "chroma": + _, is_schnell, _, _ = flux_utils.analyze_checkpoint_state(args.pretrained_model_name_or_path) + else: + is_schnell = False + self.is_schnell = is_schnell + + if args.t5xxl_max_token_length is None: + if self.is_schnell: + t5xxl_max_token_length = 256 + else: + t5xxl_max_token_length = 512 + else: + t5xxl_max_token_length = args.t5xxl_max_token_length + + logger.info(f"t5xxl_max_token_length: {t5xxl_max_token_length}") + return strategy_flux.FluxTokenizeStrategy(t5xxl_max_token_length, args.tokenizer_cache_dir) + + def get_tokenizers(self, tokenize_strategy: strategy_flux.FluxTokenizeStrategy): + return [tokenize_strategy.clip_l, tokenize_strategy.t5xxl] + + def get_latents_caching_strategy(self, args): + latents_caching_strategy = strategy_flux.FluxLatentsCachingStrategy(args.cache_latents_to_disk, args.vae_batch_size, False) + return latents_caching_strategy + + def get_text_encoding_strategy(self, args): + return strategy_flux.FluxTextEncodingStrategy(apply_t5_attn_mask=args.apply_t5_attn_mask) + + def post_process_network(self, args, accelerator, network, text_encoders, unet): + # check t5xxl is trained or not + self.train_t5xxl = network.train_t5xxl + + if self.train_t5xxl and args.cache_text_encoder_outputs: + raise ValueError( + "T5XXL is trained, so cache_text_encoder_outputs cannot be used / T5XXL学習時はcache_text_encoder_outputsは使用できません" + ) + + def get_models_for_text_encoding(self, args, accelerator, text_encoders): + if args.cache_text_encoder_outputs: + if self.train_clip_l and not self.train_t5xxl: + return text_encoders[0:1] # only CLIP-L is needed for encoding because T5XXL is cached + else: + return None # no text encoders are needed for encoding because both are cached + else: + return text_encoders # both CLIP-L and T5XXL are needed for encoding + + def get_text_encoders_train_flags(self, args, text_encoders): + return [self.train_clip_l, self.train_t5xxl] + + def get_text_encoder_outputs_caching_strategy(self, args): + if args.cache_text_encoder_outputs: + # if the text encoders is trained, we need tokenization, so is_partial is True + return strategy_flux.FluxTextEncoderOutputsCachingStrategy( + args.cache_text_encoder_outputs_to_disk, + args.text_encoder_batch_size, + args.skip_cache_check, + is_partial=self.train_clip_l or self.train_t5xxl, + apply_t5_attn_mask=args.apply_t5_attn_mask, + ) + else: + return None + + def cache_text_encoder_outputs_if_needed( + self, args, accelerator: Accelerator, unet, vae, text_encoders, dataset: train_util.DatasetGroup, weight_dtype + ): + if args.cache_text_encoder_outputs: + if not args.lowram: + # メモリ消費を減らす + logger.info("move vae and unet to cpu to save memory") + org_vae_device = vae.device + org_unet_device = unet.device + vae.to("cpu") + unet.to("cpu") + clean_memory_on_device(accelerator.device) + + # When TE is not be trained, it will not be prepared so we need to use explicit autocast + logger.info("move text encoders to gpu") + text_encoders[0].to(accelerator.device, dtype=weight_dtype) # always not fp8 + text_encoders[1].to(accelerator.device) + + if text_encoders[1].dtype == torch.float8_e4m3fn: + # if we load fp8 weights, the model is already fp8, so we use it as is + self.prepare_text_encoder_fp8(1, text_encoders[1], text_encoders[1].dtype, weight_dtype) + else: + # otherwise, we need to convert it to target dtype + text_encoders[1].to(weight_dtype) + + with accelerator.autocast(): + dataset.new_cache_text_encoder_outputs(text_encoders, accelerator) + + # cache sample prompts + if args.sample_prompts is not None: + logger.info(f"cache Text Encoder outputs for sample prompt: {args.sample_prompts}") + + tokenize_strategy: strategy_flux.FluxTokenizeStrategy = strategy_base.TokenizeStrategy.get_strategy() + text_encoding_strategy: strategy_flux.FluxTextEncodingStrategy = strategy_base.TextEncodingStrategy.get_strategy() + + prompts = train_util.load_prompts(args.sample_prompts) + sample_prompts_te_outputs = {} # key: prompt, value: text encoder outputs + with accelerator.autocast(), torch.no_grad(): + for prompt_dict in prompts: + for p in [prompt_dict.get("prompt", ""), prompt_dict.get("negative_prompt", "")]: + if p not in sample_prompts_te_outputs: + logger.info(f"cache Text Encoder outputs for prompt: {p}") + tokens_and_masks = tokenize_strategy.tokenize(p) + sample_prompts_te_outputs[p] = text_encoding_strategy.encode_tokens( + tokenize_strategy, text_encoders, tokens_and_masks, args.apply_t5_attn_mask + ) + self.sample_prompts_te_outputs = sample_prompts_te_outputs + + accelerator.wait_for_everyone() + + # move back to cpu + if not self.is_train_text_encoder(args): + logger.info("move CLIP-L back to cpu") + text_encoders[0].to("cpu") + logger.info("move t5XXL back to cpu") + text_encoders[1].to("cpu") + clean_memory_on_device(accelerator.device) + + if not args.lowram: + logger.info("move vae and unet back to original device") + vae.to(org_vae_device) + unet.to(org_unet_device) + else: + # Text Encoderから毎回出力を取得するので、GPUに乗せておく + text_encoders[0].to(accelerator.device, dtype=weight_dtype) + text_encoders[1].to(accelerator.device) + + def sample_images(self, accelerator, args, epoch, global_step, device, ae, tokenizer, text_encoder, flux): + text_encoders = text_encoder # for compatibility + text_encoders = self.get_models_for_text_encoding(args, accelerator, text_encoders) + + flux_train_utils.sample_images( + accelerator, args, epoch, global_step, flux, ae, text_encoders, self.sample_prompts_te_outputs + ) + + def get_noise_scheduler(self, args: argparse.Namespace, device: torch.device) -> Any: + noise_scheduler = sd3_train_utils.FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=args.discrete_flow_shift) + self.noise_scheduler_copy = copy.deepcopy(noise_scheduler) + return noise_scheduler + + def encode_images_to_latents(self, args, vae, images): + return vae.encode(images) + + def shift_scale_latents(self, args, latents): + return latents + + def get_noise_pred_and_target( + self, + args, + accelerator, + noise_scheduler, + latents, + batch, + text_encoder_conds, + unet: flux_models.Flux, + network, + weight_dtype, + train_unet, + is_train=True, + ): + # Sample noise that we'll add to the latents + noise = torch.randn_like(latents) + bsz = latents.shape[0] + + # get noisy model input and timesteps + noisy_model_input, timesteps, sigmas = flux_train_utils.get_noisy_model_input_and_timesteps( + args, noise_scheduler, latents, noise, accelerator.device, weight_dtype + ) + + # pack latents and get img_ids + packed_noisy_model_input = flux_utils.pack_latents(noisy_model_input) # b, c, h*2, w*2 -> b, h*w, c*4 + packed_latent_height, packed_latent_width = noisy_model_input.shape[2] // 2, noisy_model_input.shape[3] // 2 + img_ids = flux_utils.prepare_img_ids(bsz, packed_latent_height, packed_latent_width).to(device=accelerator.device) + + # get guidance + # ensure guidance_scale in args is float + guidance_vec = torch.full((bsz,), float(args.guidance_scale), device=accelerator.device) + + # get modulation vectors for Chroma + with accelerator.autocast(), torch.no_grad(): + mod_vectors = unet.get_mod_vectors(timesteps=timesteps / 1000, guidance=guidance_vec, batch_size=bsz) + + if args.gradient_checkpointing: + noisy_model_input.requires_grad_(True) + for t in text_encoder_conds: + if t is not None and t.dtype.is_floating_point: + t.requires_grad_(True) + img_ids.requires_grad_(True) + guidance_vec.requires_grad_(True) + if mod_vectors is not None: + mod_vectors.requires_grad_(True) + + # Predict the noise residual + l_pooled, t5_out, txt_ids, t5_attn_mask = text_encoder_conds + if not args.apply_t5_attn_mask: + t5_attn_mask = None + + def call_dit(img, img_ids, t5_out, txt_ids, l_pooled, timesteps, guidance_vec, t5_attn_mask, mod_vectors): + # grad is enabled even if unet is not in train mode, because Text Encoder is in train mode + with torch.set_grad_enabled(is_train), accelerator.autocast(): + # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing) + model_pred = unet( + img=img, + img_ids=img_ids, + txt=t5_out, + txt_ids=txt_ids, + y=l_pooled, + timesteps=timesteps / 1000, + guidance=guidance_vec, + txt_attention_mask=t5_attn_mask, + mod_vectors=mod_vectors, + ) + return model_pred + + model_pred = call_dit( + img=packed_noisy_model_input, + img_ids=img_ids, + t5_out=t5_out, + txt_ids=txt_ids, + l_pooled=l_pooled, + timesteps=timesteps, + guidance_vec=guidance_vec, + t5_attn_mask=t5_attn_mask, + mod_vectors=mod_vectors, + ) + + # unpack latents + model_pred = flux_utils.unpack_latents(model_pred, packed_latent_height, packed_latent_width) + + # apply model prediction type + model_pred, weighting = flux_train_utils.apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas) + + # flow matching loss: this is different from SD3 + target = noise - latents + + # differential output preservation + if "custom_attributes" in batch: + diff_output_pr_indices = [] + for i, custom_attributes in enumerate(batch["custom_attributes"]): + if "diff_output_preservation" in custom_attributes and custom_attributes["diff_output_preservation"]: + diff_output_pr_indices.append(i) + + if len(diff_output_pr_indices) > 0: + network.set_multiplier(0.0) + unet.prepare_block_swap_before_forward() + with torch.no_grad(): + model_pred_prior = call_dit( + img=packed_noisy_model_input[diff_output_pr_indices], + img_ids=img_ids[diff_output_pr_indices], + t5_out=t5_out[diff_output_pr_indices], + txt_ids=txt_ids[diff_output_pr_indices], + l_pooled=l_pooled[diff_output_pr_indices], + timesteps=timesteps[diff_output_pr_indices], + guidance_vec=guidance_vec[diff_output_pr_indices] if guidance_vec is not None else None, + t5_attn_mask=t5_attn_mask[diff_output_pr_indices] if t5_attn_mask is not None else None, + mod_vectors=mod_vectors[diff_output_pr_indices] if mod_vectors is not None else None, + ) + network.set_multiplier(1.0) # may be overwritten by "network_multipliers" in the next step + + model_pred_prior = flux_utils.unpack_latents(model_pred_prior, packed_latent_height, packed_latent_width) + model_pred_prior, _ = flux_train_utils.apply_model_prediction_type( + args, + model_pred_prior, + noisy_model_input[diff_output_pr_indices], + sigmas[diff_output_pr_indices] if sigmas is not None else None, + ) + target[diff_output_pr_indices] = model_pred_prior.to(target.dtype) + + return model_pred, target, timesteps, weighting + + def post_process_loss(self, loss, args, timesteps, noise_scheduler): + return loss + + def get_sai_model_spec(self, args): + if self.model_type != "chroma": + model_description = "schnell" if self.is_schnell else "dev" + else: + model_description = "chroma" + return train_util.get_sai_model_spec(None, args, False, True, False, flux=model_description) + + def update_metadata(self, metadata, args): + metadata["ss_model_type"] = args.model_type + metadata["ss_apply_t5_attn_mask"] = args.apply_t5_attn_mask + metadata["ss_weighting_scheme"] = args.weighting_scheme + metadata["ss_logit_mean"] = args.logit_mean + metadata["ss_logit_std"] = args.logit_std + metadata["ss_mode_scale"] = args.mode_scale + metadata["ss_guidance_scale"] = args.guidance_scale + metadata["ss_timestep_sampling"] = args.timestep_sampling + metadata["ss_sigmoid_scale"] = args.sigmoid_scale + metadata["ss_model_prediction_type"] = args.model_prediction_type + metadata["ss_discrete_flow_shift"] = args.discrete_flow_shift + + def is_text_encoder_not_needed_for_training(self, args): + return args.cache_text_encoder_outputs and not self.is_train_text_encoder(args) + + def prepare_text_encoder_grad_ckpt_workaround(self, index, text_encoder): + if index == 0: # CLIP-L + return super().prepare_text_encoder_grad_ckpt_workaround(index, text_encoder) + else: # T5XXL + text_encoder.encoder.embed_tokens.requires_grad_(True) + + def prepare_text_encoder_fp8(self, index, text_encoder, te_weight_dtype, weight_dtype): + if index == 0: # CLIP-L + logger.info(f"prepare CLIP-L for fp8: set to {te_weight_dtype}, set embeddings to {weight_dtype}") + text_encoder.to(te_weight_dtype) # fp8 + text_encoder.text_model.embeddings.to(dtype=weight_dtype) + else: # T5XXL + + def prepare_fp8(text_encoder, target_dtype): + def forward_hook(module): + def forward(hidden_states): + hidden_gelu = module.act(module.wi_0(hidden_states)) + hidden_linear = module.wi_1(hidden_states) + hidden_states = hidden_gelu * hidden_linear + hidden_states = module.dropout(hidden_states) + + hidden_states = module.wo(hidden_states) + return hidden_states + + return forward + + for module in text_encoder.modules(): + if module.__class__.__name__ in ["T5LayerNorm", "Embedding"]: + # print("set", module.__class__.__name__, "to", target_dtype) + module.to(target_dtype) + if module.__class__.__name__ in ["T5DenseGatedActDense"]: + # print("set", module.__class__.__name__, "hooks") + module.forward = forward_hook(module) + + if flux_utils.get_t5xxl_actual_dtype(text_encoder) == torch.float8_e4m3fn and text_encoder.dtype == weight_dtype: + logger.info(f"T5XXL already prepared for fp8") + else: + logger.info(f"prepare T5XXL for fp8: set to {te_weight_dtype}, set embeddings to {weight_dtype}, add hooks") + text_encoder.to(te_weight_dtype) # fp8 + prepare_fp8(text_encoder, weight_dtype) + + def on_validation_step_end(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype): + if self.is_swapping_blocks: + # prepare for next forward: because backward pass is not called, we need to prepare it here + accelerator.unwrap_model(unet).prepare_block_swap_before_forward() + + def prepare_unet_with_accelerator( + self, args: argparse.Namespace, accelerator: Accelerator, unet: torch.nn.Module + ) -> torch.nn.Module: + if not self.is_swapping_blocks: + return super().prepare_unet_with_accelerator(args, accelerator, unet) + + # if we doesn't swap blocks, we can move the model to device + flux: flux_models.Flux = unet + flux = accelerator.prepare(flux, device_placement=[not self.is_swapping_blocks]) + accelerator.unwrap_model(flux).move_to_device_except_swap_blocks(accelerator.device) # reduce peak memory usage + accelerator.unwrap_model(flux).prepare_block_swap_before_forward() + + return flux + + +def setup_parser() -> argparse.ArgumentParser: + parser = train_network.setup_parser() + train_util.add_dit_training_arguments(parser) + flux_train_utils.add_flux_train_arguments(parser) + + parser.add_argument( + "--split_mode", + action="store_true", + # help="[EXPERIMENTAL] use split mode for Flux model, network arg `train_blocks=single` is required" + # + "/[実験的] Fluxモデルの分割モードを使用する。ネットワーク引数`train_blocks=single`が必要", + help="[Deprecated] This option is deprecated. Please use `--blocks_to_swap` instead." + " / このオプションは非推奨です。代わりに`--blocks_to_swap`を使用してください。", + ) + return parser + + +if __name__ == "__main__": + parser = setup_parser() + + args = parser.parse_args() + train_util.verify_command_line_training_args(args) + args = train_util.read_config_from_file(args, parser) + + trainer = FluxNetworkTrainer() + trainer.train(args) diff --git a/gen_img.py b/gen_img.py index 9427a8940..d0c99bd17 100644 --- a/gen_img.py +++ b/gen_img.py @@ -43,8 +43,8 @@ ) from einops import rearrange from tqdm import tqdm -from torchvision import transforms from transformers import CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection, CLIPImageProcessor +from accelerate import init_empty_weights import PIL from PIL import Image from PIL.PngImagePlugin import PngInfo @@ -58,6 +58,7 @@ from tools.original_control_net import ControlNetInfo from library.original_unet import UNet2DConditionModel, InferUNet2DConditionModel from library.sdxl_original_unet import InferSdxlUNet2DConditionModel +from library.sdxl_original_control_net import SdxlControlNet from library.original_unet import FlashAttentionFunction from networks.control_net_lllite import ControlNetLLLite from library.utils import GradualLatent, EulerAncestralDiscreteSchedulerGL @@ -352,8 +353,8 @@ def __init__( self.token_replacements_list.append({}) # ControlNet - self.control_nets: List[ControlNetInfo] = [] # only for SD 1.5 - self.control_net_lllites: List[ControlNetLLLite] = [] + self.control_nets: List[Union[ControlNetInfo, Tuple[SdxlControlNet, float]]] = [] + self.control_net_lllites: List[Tuple[ControlNetLLLite, float]] = [] self.control_net_enabled = True # control_netsが空ならTrueでもFalseでもControlNetは動作しない self.gradual_latent: GradualLatent = None @@ -542,7 +543,7 @@ def __call__( else: text_embeddings = torch.cat([uncond_embeddings, text_embeddings, real_uncond_embeddings]) - if self.control_net_lllites: + if self.control_net_lllites or (self.control_nets and self.is_sdxl): # ControlNetのhintにguide imageを流用する。ControlNetの場合はControlNet側で行う if isinstance(clip_guide_images, PIL.Image.Image): clip_guide_images = [clip_guide_images] @@ -731,7 +732,12 @@ def __call__( num_latent_input = (3 if negative_scale is not None else 2) if do_classifier_free_guidance else 1 if self.control_nets: - guided_hints = original_control_net.get_guided_hints(self.control_nets, num_latent_input, batch_size, clip_guide_images) + if not self.is_sdxl: + guided_hints = original_control_net.get_guided_hints( + self.control_nets, num_latent_input, batch_size, clip_guide_images + ) + else: + clip_guide_images = clip_guide_images * 0.5 + 0.5 # [-1, 1] => [0, 1] each_control_net_enabled = [self.control_net_enabled] * len(self.control_nets) if self.control_net_lllites: @@ -793,7 +799,7 @@ def __call__( latent_model_input = latents.repeat((num_latent_input, 1, 1, 1)) latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) - # disable ControlNet-LLLite if ratio is set. ControlNet is disabled in ControlNetInfo + # disable ControlNet-LLLite or SDXL ControlNet if ratio is set. ControlNet is disabled in ControlNetInfo if self.control_net_lllites: for j, ((control_net, ratio), enabled) in enumerate(zip(self.control_net_lllites, each_control_net_enabled)): if not enabled or ratio >= 1.0: @@ -802,9 +808,16 @@ def __call__( logger.info(f"ControlNetLLLite {j} is disabled (ratio={ratio} at {i} / {len(timesteps)})") control_net.set_cond_image(None) each_control_net_enabled[j] = False + if self.control_nets and self.is_sdxl: + for j, ((control_net, ratio), enabled) in enumerate(zip(self.control_nets, each_control_net_enabled)): + if not enabled or ratio >= 1.0: + continue + if ratio < i / len(timesteps): + logger.info(f"ControlNet {j} is disabled (ratio={ratio} at {i} / {len(timesteps)})") + each_control_net_enabled[j] = False # predict the noise residual - if self.control_nets and self.control_net_enabled: + if self.control_nets and self.control_net_enabled and not self.is_sdxl: if regional_network: num_sub_and_neg_prompts = len(text_embeddings) // batch_size text_emb_last = text_embeddings[num_sub_and_neg_prompts - 2 :: num_sub_and_neg_prompts] # last subprompt @@ -823,6 +836,31 @@ def __call__( text_embeddings, text_emb_last, ).sample + elif self.control_nets: + input_resi_add_list = [] + mid_add_list = [] + for (control_net, _), enbld in zip(self.control_nets, each_control_net_enabled): + if not enbld: + continue + input_resi_add, mid_add = control_net( + latent_model_input, t, text_embeddings, vector_embeddings, clip_guide_images + ) + input_resi_add_list.append(input_resi_add) + mid_add_list.append(mid_add) + if len(input_resi_add_list) == 0: + noise_pred = self.unet(latent_model_input, t, text_embeddings, vector_embeddings) + else: + if len(input_resi_add_list) > 1: + # get mean of input_resi_add_list and mid_add_list + input_resi_add_mean = [] + for i in range(len(input_resi_add_list[0])): + input_resi_add_mean.append( + torch.mean(torch.stack([input_resi_add_list[j][i] for j in range(len(input_resi_add_list))], dim=0)) + ) + input_resi_add = input_resi_add_mean + mid_add = torch.mean(torch.stack(mid_add_list), dim=0) + + noise_pred = self.unet(latent_model_input, t, text_embeddings, vector_embeddings, input_resi_add, mid_add) elif self.is_sdxl: noise_pred = self.unet(latent_model_input, t, text_embeddings, vector_embeddings) else: @@ -1825,16 +1863,37 @@ def __getattr__(self, item): upscaler.to(dtype).to(device) # ControlNetの処理 - control_nets: List[ControlNetInfo] = [] + control_nets: List[Union[ControlNetInfo, Tuple[SdxlControlNet, float]]] = [] if args.control_net_models: - for i, model in enumerate(args.control_net_models): - prep_type = None if not args.control_net_preps or len(args.control_net_preps) <= i else args.control_net_preps[i] - weight = 1.0 if not args.control_net_weights or len(args.control_net_weights) <= i else args.control_net_weights[i] - ratio = 1.0 if not args.control_net_ratios or len(args.control_net_ratios) <= i else args.control_net_ratios[i] + if not is_sdxl: + for i, model in enumerate(args.control_net_models): + prep_type = None if not args.control_net_preps or len(args.control_net_preps) <= i else args.control_net_preps[i] + weight = 1.0 if not args.control_net_weights or len(args.control_net_weights) <= i else args.control_net_weights[i] + ratio = 1.0 if not args.control_net_ratios or len(args.control_net_ratios) <= i else args.control_net_ratios[i] + + ctrl_unet, ctrl_net = original_control_net.load_control_net(args.v2, unet, model) + prep = original_control_net.load_preprocess(prep_type) + control_nets.append(ControlNetInfo(ctrl_unet, ctrl_net, prep, weight, ratio)) + else: + for i, model_file in enumerate(args.control_net_models): + multiplier = ( + 1.0 + if not args.control_net_multipliers or len(args.control_net_multipliers) <= i + else args.control_net_multipliers[i] + ) + ratio = 1.0 if not args.control_net_ratios or len(args.control_net_ratios) <= i else args.control_net_ratios[i] + + logger.info(f"loading SDXL ControlNet: {model_file}") + from safetensors.torch import load_file + + state_dict = load_file(model_file) - ctrl_unet, ctrl_net = original_control_net.load_control_net(args.v2, unet, model) - prep = original_control_net.load_preprocess(prep_type) - control_nets.append(ControlNetInfo(ctrl_unet, ctrl_net, prep, weight, ratio)) + logger.info(f"Initializing SDXL ControlNet with multiplier: {multiplier}") + with init_empty_weights(): + control_net = SdxlControlNet(multiplier=multiplier) + control_net.load_state_dict(state_dict) + control_net.to(dtype).to(device) + control_nets.append((control_net, ratio)) control_net_lllites: List[Tuple[ControlNetLLLite, float]] = [] if args.control_net_lllite_models: diff --git a/hunyuan_image_minimal_inference.py b/hunyuan_image_minimal_inference.py new file mode 100644 index 000000000..8c14cf6f1 --- /dev/null +++ b/hunyuan_image_minimal_inference.py @@ -0,0 +1,1268 @@ +import argparse +import datetime +import gc +from importlib.util import find_spec +import random +import os +import re +import time +import copy +from types import ModuleType, SimpleNamespace +from typing import Tuple, Optional, List, Any, Dict, Union + +import numpy as np +import torch +from safetensors.torch import load_file, save_file +from safetensors import safe_open +from tqdm import tqdm +from diffusers.utils.torch_utils import randn_tensor +from PIL import Image + +from library import hunyuan_image_models, hunyuan_image_text_encoder, hunyuan_image_utils +from library import hunyuan_image_vae +from library.hunyuan_image_vae import HunyuanVAE2D +from library.device_utils import clean_memory_on_device, synchronize_device +from library.safetensors_utils import mem_eff_save_file +from networks import lora_hunyuan_image + + +lycoris_available = find_spec("lycoris") is not None +if lycoris_available: + from lycoris.kohya import create_network_from_weights + +from library.utils import setup_logging + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + + +class GenerationSettings: + def __init__(self, device: torch.device, dit_weight_dtype: Optional[torch.dtype] = None): + self.device = device + self.dit_weight_dtype = dit_weight_dtype # not used currently because model may be optimized + + +def parse_args() -> argparse.Namespace: + """parse command line arguments""" + parser = argparse.ArgumentParser(description="HunyuanImage inference script") + + parser.add_argument("--dit", type=str, default=None, help="DiT directory or path") + parser.add_argument("--vae", type=str, default=None, help="VAE directory or path") + parser.add_argument("--text_encoder", type=str, required=True, help="Text Encoder 1 (Qwen2.5-VL) directory or path") + parser.add_argument("--byt5", type=str, default=None, help="ByT5 Text Encoder 2 directory or path") + + # LoRA + parser.add_argument("--lora_weight", type=str, nargs="*", required=False, default=None, help="LoRA weight path") + parser.add_argument("--lora_multiplier", type=float, nargs="*", default=1.0, help="LoRA multiplier") + parser.add_argument("--include_patterns", type=str, nargs="*", default=None, help="LoRA module include patterns") + parser.add_argument("--exclude_patterns", type=str, nargs="*", default=None, help="LoRA module exclude patterns") + parser.add_argument( + "--save_merged_model", + type=str, + default=None, + help="Save merged model to path. If specified, no inference will be performed.", + ) + + # inference + parser.add_argument( + "--guidance_scale", type=float, default=3.5, help="Guidance scale for classifier free guidance. Default is 3.5." + ) + parser.add_argument( + "--apg_start_step_ocr", + type=int, + default=38, + help="Starting step for Adaptive Projected Guidance (APG) for image with text. Default is 38. Should be less than infer_steps, usually near the end.", + ) + parser.add_argument( + "--apg_start_step_general", + type=int, + default=5, + help="Starting step for Adaptive Projected Guidance (APG) for general image. Default is 5. Should be less than infer_steps, usually near the beginning.", + ) + parser.add_argument( + "--guidance_rescale", + type=float, + default=0.0, + help="Guidance rescale factor for steps without APG, 0.0 to 1.0. Default is 0.0 (no rescale).", + ) + parser.add_argument( + "--guidance_rescale_apg", + type=float, + default=0.0, + help="Guidance rescale factor for steps with APG, 0.0 to 1.0. Default is 0.0 (no rescale).", + ) + parser.add_argument("--prompt", type=str, default=None, help="prompt for generation") + parser.add_argument("--negative_prompt", type=str, default="", help="negative prompt for generation, default is empty string") + parser.add_argument("--image_size", type=int, nargs=2, default=[2048, 2048], help="image size, height and width") + parser.add_argument("--infer_steps", type=int, default=50, help="number of inference steps, default is 50") + parser.add_argument("--save_path", type=str, required=True, help="path to save generated video") + parser.add_argument("--seed", type=int, default=None, help="Seed for evaluation.") + + # Flow Matching + parser.add_argument( + "--flow_shift", + type=float, + default=5.0, + help="Shift factor for flow matching schedulers. Default is 5.0.", + ) + + parser.add_argument("--fp8", action="store_true", help="use fp8 for DiT model") + parser.add_argument("--fp8_scaled", action="store_true", help="use scaled fp8 for DiT, only for fp8") + + parser.add_argument("--text_encoder_cpu", action="store_true", help="Inference on CPU for Text Encoders") + parser.add_argument( + "--vae_chunk_size", + type=int, + default=None, # default is None (no chunking) + help="Chunk size for VAE decoding to reduce memory usage. Default is None (no chunking). 16 is recommended if enabled" + " / メモリ使用量を減らすためのVAEデコードのチャンクサイズ。デフォルトはNone(チャンクなし)。有効にする場合は16程度を推奨。", + ) + parser.add_argument( + "--device", type=str, default=None, help="device to use for inference. If None, use CUDA if available, otherwise use CPU" + ) + parser.add_argument( + "--attn_mode", + type=str, + default="torch", + choices=["flash", "torch", "sageattn", "xformers", "sdpa"], # "sdpa" for backward compatibility + help="attention mode", + ) + parser.add_argument("--blocks_to_swap", type=int, default=0, help="number of blocks to swap in the model") + parser.add_argument( + "--output_type", + type=str, + default="images", + choices=["images", "latent", "latent_images"], + help="output type", + ) + parser.add_argument("--no_metadata", action="store_true", help="do not save metadata") + parser.add_argument("--latent_path", type=str, nargs="*", default=None, help="path to latent for decode. no inference") + parser.add_argument( + "--lycoris", action="store_true", help=f"use lycoris for inference{'' if lycoris_available else ' (not available)'}" + ) + + # arguments for batch and interactive modes + parser.add_argument("--from_file", type=str, default=None, help="Read prompts from a file") + parser.add_argument("--interactive", action="store_true", help="Interactive mode: read prompts from console") + + args = parser.parse_args() + + # Validate arguments + if args.from_file and args.interactive: + raise ValueError("Cannot use both --from_file and --interactive at the same time") + + if args.latent_path is None or len(args.latent_path) == 0: + if args.prompt is None and not args.from_file and not args.interactive: + raise ValueError("Either --prompt, --from_file or --interactive must be specified") + + if args.lycoris and not lycoris_available: + raise ValueError("install lycoris: https://github.com/KohakuBlueleaf/LyCORIS") + + if args.attn_mode == "sdpa": + args.attn_mode = "torch" # backward compatibility + + return args + + +def parse_prompt_line(line: str) -> Dict[str, Any]: + """Parse a prompt line into a dictionary of argument overrides + + Args: + line: Prompt line with options + + Returns: + Dict[str, Any]: Dictionary of argument overrides + """ + # TODO common function with hv_train_network.line_to_prompt_dict + parts = line.split(" --") + prompt = parts[0].strip() + + # Create dictionary of overrides + overrides = {"prompt": prompt} + + for part in parts[1:]: + if not part.strip(): + continue + option_parts = part.split(" ", 1) + option = option_parts[0].strip() + value = option_parts[1].strip() if len(option_parts) > 1 else "" + + # Map options to argument names + if option == "w": + overrides["image_size_width"] = int(value) + elif option == "h": + overrides["image_size_height"] = int(value) + elif option == "d": + overrides["seed"] = int(value) + elif option == "s": + overrides["infer_steps"] = int(value) + elif option == "g" or option == "l": + overrides["guidance_scale"] = float(value) + elif option == "fs": + overrides["flow_shift"] = float(value) + # elif option == "i": + # overrides["image_path"] = value + # elif option == "im": + # overrides["image_mask_path"] = value + # elif option == "cn": + # overrides["control_path"] = value + elif option == "n": + overrides["negative_prompt"] = value + # elif option == "ci": # control_image_path + # overrides["control_image_path"] = value + + return overrides + + +def apply_overrides(args: argparse.Namespace, overrides: Dict[str, Any]) -> argparse.Namespace: + """Apply overrides to args + + Args: + args: Original arguments + overrides: Dictionary of overrides + + Returns: + argparse.Namespace: New arguments with overrides applied + """ + args_copy = copy.deepcopy(args) + + for key, value in overrides.items(): + if key == "image_size_width": + args_copy.image_size[1] = value + elif key == "image_size_height": + args_copy.image_size[0] = value + else: + setattr(args_copy, key, value) + + return args_copy + + +def check_inputs(args: argparse.Namespace) -> Tuple[int, int]: + """Validate video size and length + + Args: + args: command line arguments + + Returns: + Tuple[int, int]: (height, width) + """ + height = args.image_size[0] + width = args.image_size[1] + + if height % 32 != 0 or width % 32 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 32 but are {height} and {width}.") + + return height, width + + +# region Model + + +def load_dit_model( + args: argparse.Namespace, device: torch.device, dit_weight_dtype: Optional[torch.dtype] = None +) -> hunyuan_image_models.HYImageDiffusionTransformer: + """load DiT model + + Args: + args: command line arguments + device: device to use + dit_weight_dtype: data type for the model weights. None for as-is + + Returns: + qwen_image_model.HYImageDiffusionTransformer: DiT model instance + """ + # If LyCORIS is enabled, we will load the model to CPU and then merge LoRA weights (static method) + + loading_device = "cpu" + if args.blocks_to_swap == 0 and not args.lycoris: + loading_device = device + + # load LoRA weights + if not args.lycoris and args.lora_weight is not None and len(args.lora_weight) > 0: + lora_weights_list = [] + for lora_weight in args.lora_weight: + logger.info(f"Loading LoRA weight from: {lora_weight}") + lora_sd = load_file(lora_weight) # load on CPU, dtype is as is + # lora_sd = filter_lora_state_dict(lora_sd, args.include_patterns, args.exclude_patterns) + lora_weights_list.append(lora_sd) + else: + lora_weights_list = None + + loading_weight_dtype = dit_weight_dtype + if args.fp8_scaled and not args.lycoris: + loading_weight_dtype = None # we will load weights as-is and then optimize to fp8 + + model = hunyuan_image_models.load_hunyuan_image_model( + device, + args.dit, + args.attn_mode, + True, # enable split_attn to trim masked tokens + loading_device, + loading_weight_dtype, + args.fp8_scaled and not args.lycoris, + lora_weights_list=lora_weights_list, + lora_multipliers=args.lora_multiplier, + ) + + # merge LoRA weights + if args.lycoris: + if args.lora_weight is not None and len(args.lora_weight) > 0: + merge_lora_weights(lora_hunyuan_image, model, args, device) + + if args.fp8_scaled: + # load state dict as-is and optimize to fp8 + state_dict = model.state_dict() + + # if no blocks to swap, we can move the weights to GPU after optimization on GPU (omit redundant CPU->GPU copy) + move_to_device = args.blocks_to_swap == 0 # if blocks_to_swap > 0, we will keep the model on CPU + state_dict = model.fp8_optimization(state_dict, device, move_to_device, use_scaled_mm=False) # args.fp8_fast) + + info = model.load_state_dict(state_dict, strict=True, assign=True) + logger.info(f"Loaded FP8 optimized weights: {info}") + + # if we only want to save the model, we can skip the rest + if args.save_merged_model: + return None + + if not args.fp8_scaled: + # simple cast to dit_weight_dtype + target_dtype = None # load as-is (dit_weight_dtype == dtype of the weights in state_dict) + target_device = None + + if dit_weight_dtype is not None: # in case of args.fp8 and not args.fp8_scaled + logger.info(f"Convert model to {dit_weight_dtype}") + target_dtype = dit_weight_dtype + + if args.blocks_to_swap == 0: + logger.info(f"Move model to device: {device}") + target_device = device + + model.to(target_device, target_dtype) # move and cast at the same time. this reduces redundant copy operations + + # if args.compile: + # compile_backend, compile_mode, compile_dynamic, compile_fullgraph = args.compile_args + # logger.info( + # f"Torch Compiling[Backend: {compile_backend}; Mode: {compile_mode}; Dynamic: {compile_dynamic}; Fullgraph: {compile_fullgraph}]" + # ) + # torch._dynamo.config.cache_size_limit = 32 + # for i in range(len(model.blocks)): + # model.blocks[i] = torch.compile( + # model.blocks[i], + # backend=compile_backend, + # mode=compile_mode, + # dynamic=compile_dynamic.lower() in "true", + # fullgraph=compile_fullgraph.lower() in "true", + # ) + + if args.blocks_to_swap > 0: + logger.info(f"Enable swap {args.blocks_to_swap} blocks to CPU from device: {device}") + model.enable_block_swap(args.blocks_to_swap, device, supports_backward=False) + model.move_to_device_except_swap_blocks(device) + model.prepare_block_swap_before_forward() + else: + # make sure the model is on the right device + model.to(device) + + model.eval().requires_grad_(False) + clean_memory_on_device(device) + + return model + + +def merge_lora_weights( + lora_module: ModuleType, + model: torch.nn.Module, + lora_weights: List[str], + lora_multipliers: List[float], + include_patterns: Optional[List[str]], + exclude_patterns: Optional[List[str]], + device: torch.device, + lycoris: bool = False, + save_merged_model: Optional[str] = None, + converter: Optional[callable] = None, +) -> None: + """merge LoRA weights to the model + + Args: + lora_module: LoRA module, e.g. lora_wan + model: DiT model + lora_weights: paths to LoRA weights + lora_multipliers: multipliers for LoRA weights + include_patterns: regex patterns to include LoRA modules + exclude_patterns: regex patterns to exclude LoRA modules + device: torch.device + lycoris: use LyCORIS + save_merged_model: path to save merged model, if specified, no inference will be performed + converter: Optional[callable] = None + """ + if lora_weights is None or len(lora_weights) == 0: + return + + for i, lora_weight in enumerate(lora_weights): + if lora_multipliers is not None and len(lora_multipliers) > i: + lora_multiplier = lora_multipliers[i] + else: + lora_multiplier = 1.0 + + logger.info(f"Loading LoRA weights from {lora_weight} with multiplier {lora_multiplier}") + weights_sd = load_file(lora_weight) + if converter is not None: + weights_sd = converter(weights_sd) + + # apply include/exclude patterns + original_key_count = len(weights_sd.keys()) + if include_patterns is not None and len(include_patterns) > i: + include_pattern = include_patterns[i] + regex_include = re.compile(include_pattern) + weights_sd = {k: v for k, v in weights_sd.items() if regex_include.search(k)} + logger.info(f"Filtered keys with include pattern {include_pattern}: {original_key_count} -> {len(weights_sd.keys())}") + if exclude_patterns is not None and len(exclude_patterns) > i: + original_key_count_ex = len(weights_sd.keys()) + exclude_pattern = exclude_patterns[i] + regex_exclude = re.compile(exclude_pattern) + weights_sd = {k: v for k, v in weights_sd.items() if not regex_exclude.search(k)} + logger.info( + f"Filtered keys with exclude pattern {exclude_pattern}: {original_key_count_ex} -> {len(weights_sd.keys())}" + ) + if len(weights_sd) != original_key_count: + remaining_keys = list(set([k.split(".", 1)[0] for k in weights_sd.keys()])) + remaining_keys.sort() + logger.info(f"Remaining LoRA modules after filtering: {remaining_keys}") + if len(weights_sd) == 0: + logger.warning("No keys left after filtering.") + + if lycoris: + lycoris_net, _ = create_network_from_weights( + multiplier=lora_multiplier, + file=None, + weights_sd=weights_sd, + unet=model, + text_encoder=None, + vae=None, + for_inference=True, + ) + lycoris_net.merge_to(None, model, weights_sd, dtype=None, device=device) + else: + network = lora_module.create_arch_network_from_weights(lora_multiplier, weights_sd, unet=model, for_inference=True) + network.merge_to(None, model, weights_sd, device=device, non_blocking=True) + + synchronize_device(device) + logger.info("LoRA weights loaded") + + # save model here before casting to dit_weight_dtype + if save_merged_model: + logger.info(f"Saving merged model to {save_merged_model}") + mem_eff_save_file(model.state_dict(), save_merged_model) # save_file needs a lot of memory + logger.info("Merged model saved") + + +# endregion + + +def decode_latent(vae: HunyuanVAE2D, latent: torch.Tensor, device: torch.device) -> torch.Tensor: + logger.info(f"Decoding image. Latent shape {latent.shape}, device {device}") + + vae.to(device) + with torch.no_grad(): + latent = latent / vae.scaling_factor # scale latent back to original range + pixels = vae.decode(latent.to(device, dtype=vae.dtype)) + pixels = pixels.to("cpu", dtype=torch.float32) # move to CPU and convert to float32 (bfloat16 is not supported by numpy) + vae.to("cpu") + + logger.info(f"Decoded. Pixel shape {pixels.shape}") + return pixels[0] # remove batch dimension + + +def prepare_text_inputs( + args: argparse.Namespace, device: torch.device, shared_models: Optional[Dict] = None +) -> Tuple[Dict[str, Any], Dict[str, Any]]: + """Prepare text-related inputs for T2I: LLM encoding.""" + + # load text encoder: conds_cache holds cached encodings for prompts without padding + conds_cache = {} + vl_device = torch.device("cpu") if args.text_encoder_cpu else device + if shared_models is not None: + tokenizer_vlm = shared_models.get("tokenizer_vlm") + text_encoder_vlm = shared_models.get("text_encoder_vlm") + tokenizer_byt5 = shared_models.get("tokenizer_byt5") + text_encoder_byt5 = shared_models.get("text_encoder_byt5") + + if "conds_cache" in shared_models: # Use shared cache if available + conds_cache = shared_models["conds_cache"] + + # text_encoder is on device (batched inference) or CPU (interactive inference) + else: # Load if not in shared_models + vl_dtype = torch.bfloat16 # Default dtype for Text Encoder + tokenizer_vlm, text_encoder_vlm = hunyuan_image_text_encoder.load_qwen2_5_vl( + args.text_encoder, dtype=vl_dtype, device=vl_device, disable_mmap=True + ) + tokenizer_byt5, text_encoder_byt5 = hunyuan_image_text_encoder.load_byt5( + args.byt5, dtype=torch.float16, device=vl_device, disable_mmap=True + ) + + # Store original devices to move back later if they were shared. This does nothing if shared_models is None + text_encoder_original_device = text_encoder_vlm.device if text_encoder_vlm else None + + # Ensure text_encoder is not None before proceeding + if not text_encoder_vlm or not tokenizer_vlm or not tokenizer_byt5 or not text_encoder_byt5: + raise ValueError("Text encoder or tokenizer is not loaded properly.") + + # Define a function to move models to device if needed + # This is to avoid moving models if not needed, especially in interactive mode + model_is_moved = False + + def move_models_to_device_if_needed(): + nonlocal model_is_moved + nonlocal shared_models + + if model_is_moved: + return + model_is_moved = True + + logger.info(f"Moving DiT and Text Encoder to appropriate device: {device} or CPU") + if shared_models and "model" in shared_models: # DiT model is shared + if args.blocks_to_swap > 0: + logger.info("Waiting for 5 seconds to finish block swap") + time.sleep(5) + model = shared_models["model"] + model.to("cpu") + clean_memory_on_device(device) # clean memory on device before moving models + + text_encoder_vlm.to(vl_device) # If text_encoder_cpu is True, this will be CPU + text_encoder_byt5.to(vl_device) + + logger.info("Encoding prompt with Text Encoder") + + prompt = args.prompt + cache_key = prompt + if cache_key in conds_cache: + embed, mask, embed_byt5, mask_byt5, ocr_mask = conds_cache[cache_key] + else: + move_models_to_device_if_needed() + + with torch.no_grad(): + embed, mask = hunyuan_image_text_encoder.get_qwen_prompt_embeds(tokenizer_vlm, text_encoder_vlm, prompt) + ocr_mask, embed_byt5, mask_byt5 = hunyuan_image_text_encoder.get_glyph_prompt_embeds( + tokenizer_byt5, text_encoder_byt5, prompt + ) + embed = embed.cpu() + mask = mask.cpu() + embed_byt5 = embed_byt5.cpu() + mask_byt5 = mask_byt5.cpu() + + conds_cache[cache_key] = (embed, mask, embed_byt5, mask_byt5, ocr_mask) + + negative_prompt = args.negative_prompt + cache_key = negative_prompt + if cache_key in conds_cache: + negative_embed, negative_mask, negative_embed_byt5, negative_mask_byt5, negative_ocr_mask = conds_cache[cache_key] + else: + move_models_to_device_if_needed() + + with torch.no_grad(): + negative_embed, negative_mask = hunyuan_image_text_encoder.get_qwen_prompt_embeds( + tokenizer_vlm, text_encoder_vlm, negative_prompt + ) + negative_ocr_mask, negative_embed_byt5, negative_mask_byt5 = hunyuan_image_text_encoder.get_glyph_prompt_embeds( + tokenizer_byt5, text_encoder_byt5, negative_prompt + ) + negative_embed = negative_embed.cpu() + negative_mask = negative_mask.cpu() + negative_embed_byt5 = negative_embed_byt5.cpu() + negative_mask_byt5 = negative_mask_byt5.cpu() + + conds_cache[cache_key] = (negative_embed, negative_mask, negative_embed_byt5, negative_mask_byt5, negative_ocr_mask) + + if not (shared_models and "text_encoder_vlm" in shared_models): # if loaded locally + # There is a bug text_encoder is not freed from GPU memory when text encoder is fp8 + del tokenizer_vlm, text_encoder_vlm, tokenizer_byt5, text_encoder_byt5 + gc.collect() # This may force Text Encoder to be freed from GPU memory + else: # if shared, move back to original device (likely CPU) + if text_encoder_vlm: + text_encoder_vlm.to(text_encoder_original_device) + if text_encoder_byt5: + text_encoder_byt5.to(text_encoder_original_device) + + clean_memory_on_device(device) + + arg_c = {"embed": embed, "mask": mask, "embed_byt5": embed_byt5, "mask_byt5": mask_byt5, "ocr_mask": ocr_mask, "prompt": prompt} + arg_null = { + "embed": negative_embed, + "mask": negative_mask, + "embed_byt5": negative_embed_byt5, + "mask_byt5": negative_mask_byt5, + "ocr_mask": negative_ocr_mask, + "prompt": negative_prompt, + } + + return arg_c, arg_null + + +def generate( + args: argparse.Namespace, + gen_settings: GenerationSettings, + shared_models: Optional[Dict] = None, + precomputed_text_data: Optional[Dict] = None, +) -> torch.Tensor: + """main function for generation + + Args: + args: command line arguments + shared_models: dictionary containing pre-loaded models (mainly for DiT) + precomputed_image_data: Optional dictionary with precomputed image data + precomputed_text_data: Optional dictionary with precomputed text data + + Returns: + tuple: (HunyuanVAE2D model (vae) or None, torch.Tensor generated latent) + """ + device, dit_weight_dtype = (gen_settings.device, gen_settings.dit_weight_dtype) + + # prepare seed + seed = args.seed if args.seed is not None else random.randint(0, 2**32 - 1) + args.seed = seed # set seed to args for saving + + if precomputed_text_data is not None: + logger.info("Using precomputed text data.") + context = precomputed_text_data["context"] + context_null = precomputed_text_data["context_null"] + + else: + logger.info("No precomputed data. Preparing image and text inputs.") + context, context_null = prepare_text_inputs(args, device, shared_models) + + if shared_models is None or "model" not in shared_models: + # load DiT model + model = load_dit_model(args, device, dit_weight_dtype) + + # if we only want to save the model, we can skip the rest + if args.save_merged_model: + return None + + if shared_models is not None: + shared_models["model"] = model + else: + # use shared model + logger.info("Using shared DiT model.") + model: hunyuan_image_models.HYImageDiffusionTransformer = shared_models["model"] + model.move_to_device_except_swap_blocks(device) # Handles block swap correctly + model.prepare_block_swap_before_forward() + + return generate_body(args, model, context, context_null, device, seed) + + +def generate_body( + args: Union[argparse.Namespace, SimpleNamespace], + model: hunyuan_image_models.HYImageDiffusionTransformer, + context: Dict[str, Any], + context_null: Optional[Dict[str, Any]], + device: torch.device, + seed: int, +) -> torch.Tensor: + + # set random generator + seed_g = torch.Generator(device="cpu") + seed_g.manual_seed(seed) + + height, width = check_inputs(args) + logger.info(f"Image size: {height}x{width} (HxW), infer_steps: {args.infer_steps}") + + # image generation ###### + + logger.info(f"Prompt: {context['prompt']}") + + embed = context["embed"].to(device, dtype=torch.bfloat16) + mask = context["mask"].to(device, dtype=torch.bfloat16) + embed_byt5 = context["embed_byt5"].to(device, dtype=torch.bfloat16) + mask_byt5 = context["mask_byt5"].to(device, dtype=torch.bfloat16) + ocr_mask = context["ocr_mask"] # list of bool + + if context_null is None: + context_null = context # dummy for unconditional + + negative_embed = context_null["embed"].to(device, dtype=torch.bfloat16) + negative_mask = context_null["mask"].to(device, dtype=torch.bfloat16) + negative_embed_byt5 = context_null["embed_byt5"].to(device, dtype=torch.bfloat16) + negative_mask_byt5 = context_null["mask_byt5"].to(device, dtype=torch.bfloat16) + # negative_ocr_mask = context_null["ocr_mask"] # list of bool + + # Prepare latent variables + num_channels_latents = model.in_channels + shape = (1, num_channels_latents, height // hunyuan_image_vae.VAE_SCALE_FACTOR, width // hunyuan_image_vae.VAE_SCALE_FACTOR) + latents = randn_tensor(shape, generator=seed_g, device=device, dtype=torch.bfloat16) + + logger.info( + f"Embed: {embed.shape}, embed byt5: {embed_byt5.shape}, negative_embed: {negative_embed.shape}, negative embed byt5: {negative_embed_byt5.shape}, latents: {latents.shape}" + ) + + # Prepare timesteps + timesteps, sigmas = hunyuan_image_utils.get_timesteps_sigmas(args.infer_steps, args.flow_shift, device) + + # Prepare Guider + cfg_guider_ocr = hunyuan_image_utils.AdaptiveProjectedGuidance( + guidance_scale=10.0, + eta=0.0, + adaptive_projected_guidance_rescale=10.0, + adaptive_projected_guidance_momentum=-0.5, + guidance_rescale=args.guidance_rescale_apg, + ) + cfg_guider_general = hunyuan_image_utils.AdaptiveProjectedGuidance( + guidance_scale=10.0, + eta=0.0, + adaptive_projected_guidance_rescale=10.0, + adaptive_projected_guidance_momentum=-0.5, + guidance_rescale=args.guidance_rescale_apg, + ) + + # Denoising loop + do_cfg = args.guidance_scale != 1.0 + # print(f"embed shape: {embed.shape}, mean: {embed.mean()}, std: {embed.std()}") + # print(f"embed_byt5 shape: {embed_byt5.shape}, mean: {embed_byt5.mean()}, std: {embed_byt5.std()}") + # print(f"negative_embed shape: {negative_embed.shape}, mean: {negative_embed.mean()}, std: {negative_embed.std()}") + # print(f"negative_embed_byt5 shape: {negative_embed_byt5.shape}, mean: {negative_embed_byt5.mean()}, std: {negative_embed_byt5.std()}") + # print(f"latents shape: {latents.shape}, mean: {latents.mean()}, std: {latents.std()}") + # print(f"mask shape: {mask.shape}, sum: {mask.sum()}") + # print(f"mask_byt5 shape: {mask_byt5.shape}, sum: {mask_byt5.sum()}") + # print(f"negative_mask shape: {negative_mask.shape}, sum: {negative_mask.sum()}") + # print(f"negative_mask_byt5 shape: {negative_mask_byt5.shape}, sum: {negative_mask_byt5.sum()}") + + autocast_enabled = args.fp8 + + with tqdm(total=len(timesteps), desc="Denoising steps") as pbar: + for i, t in enumerate(timesteps): + t_expand = t.expand(latents.shape[0]).to(torch.int64) + + with torch.no_grad(), torch.autocast(device_type=device.type, dtype=torch.bfloat16, enabled=autocast_enabled): + noise_pred = model(latents, t_expand, embed, mask, embed_byt5, mask_byt5) + + if do_cfg: + with torch.no_grad(), torch.autocast(device_type=device.type, dtype=torch.bfloat16, enabled=autocast_enabled): + uncond_noise_pred = model( + latents, t_expand, negative_embed, negative_mask, negative_embed_byt5, negative_mask_byt5 + ) + noise_pred = hunyuan_image_utils.apply_classifier_free_guidance( + noise_pred, + uncond_noise_pred, + ocr_mask[0], + args.guidance_scale, + i, + apg_start_step_ocr=args.apg_start_step_ocr, + apg_start_step_general=args.apg_start_step_general, + cfg_guider_ocr=cfg_guider_ocr, + cfg_guider_general=cfg_guider_general, + guidance_rescale=args.guidance_rescale, + ) + + # ensure latents dtype is consistent + latents = hunyuan_image_utils.step(latents, noise_pred, sigmas, i).to(latents.dtype) + + pbar.update() + + return latents + + +def get_time_flag(): + return datetime.datetime.fromtimestamp(time.time()).strftime("%Y%m%d-%H%M%S-%f")[:-3] + + +def save_latent(latent: torch.Tensor, args: argparse.Namespace, height: int, width: int) -> str: + """Save latent to file + + Args: + latent: Latent tensor + args: command line arguments + height: height of frame + width: width of frame + + Returns: + str: Path to saved latent file + """ + save_path = args.save_path + os.makedirs(save_path, exist_ok=True) + time_flag = get_time_flag() + + seed = args.seed + + latent_path = f"{save_path}/{time_flag}_{seed}_latent.safetensors" + + if args.no_metadata: + metadata = None + else: + metadata = { + "seeds": f"{seed}", + "prompt": f"{args.prompt}", + "height": f"{height}", + "width": f"{width}", + "infer_steps": f"{args.infer_steps}", + # "embedded_cfg_scale": f"{args.embedded_cfg_scale}", + "guidance_scale": f"{args.guidance_scale}", + } + if args.negative_prompt is not None: + metadata["negative_prompt"] = f"{args.negative_prompt}" + + sd = {"latent": latent.contiguous()} + save_file(sd, latent_path, metadata=metadata) + logger.info(f"Latent saved to: {latent_path}") + + return latent_path + + +def save_images(sample: torch.Tensor, args: argparse.Namespace, original_base_name: Optional[str] = None) -> str: + """Save images to directory + + Args: + sample: Video tensor + args: command line arguments + original_base_name: Original base name (if latents are loaded from files) + + Returns: + str: Path to saved images directory + """ + save_path = args.save_path + os.makedirs(save_path, exist_ok=True) + time_flag = get_time_flag() + + seed = args.seed + original_name = "" if original_base_name is None else f"_{original_base_name}" + image_name = f"{time_flag}_{seed}{original_name}" + + x = torch.clamp(sample, -1.0, 1.0) + x = ((x + 1.0) * 127.5).to(torch.uint8).cpu().numpy() + x = x.transpose(1, 2, 0) # C, H, W -> H, W, C + + image = Image.fromarray(x) + image.save(os.path.join(save_path, f"{image_name}.png")) + + logger.info(f"Sample images saved to: {save_path}/{image_name}") + + return f"{save_path}/{image_name}" + + +def save_output( + args: argparse.Namespace, + vae: HunyuanVAE2D, + latent: torch.Tensor, + device: torch.device, + original_base_name: Optional[str] = None, +) -> None: + """save output + + Args: + args: command line arguments + vae: VAE model + latent: latent tensor + device: device to use + original_base_name: original base name (if latents are loaded from files) + """ + height, width = latent.shape[-2], latent.shape[-1] # BCTHW + height *= hunyuan_image_vae.VAE_SCALE_FACTOR + width *= hunyuan_image_vae.VAE_SCALE_FACTOR + # print(f"Saving output. Latent shape {latent.shape}; pixel shape {height}x{width}") + if args.output_type == "latent" or args.output_type == "latent_images": + # save latent + save_latent(latent, args, height, width) + if args.output_type == "latent": + return + + if vae is None: + logger.error("VAE is None, cannot decode latents for saving video/images.") + return + + if latent.ndim == 2: # S,C. For packed latents from other inference scripts + latent = latent.unsqueeze(0) + height, width = check_inputs(args) # Get height/width from args + latent = latent.view( + 1, vae.latent_channels, height // hunyuan_image_vae.VAE_SCALE_FACTOR, width // hunyuan_image_vae.VAE_SCALE_FACTOR + ) + + image = decode_latent(vae, latent, device) + + if args.output_type == "images" or args.output_type == "latent_images": + # save images + if original_base_name is None: + original_name = "" + else: + original_name = f"_{original_base_name}" + save_images(image, args, original_name) + + +def preprocess_prompts_for_batch(prompt_lines: List[str], base_args: argparse.Namespace) -> List[Dict]: + """Process multiple prompts for batch mode + + Args: + prompt_lines: List of prompt lines + base_args: Base command line arguments + + Returns: + List[Dict]: List of prompt data dictionaries + """ + prompts_data = [] + + for line in prompt_lines: + line = line.strip() + if not line or line.startswith("#"): # Skip empty lines and comments + continue + + # Parse prompt line and create override dictionary + prompt_data = parse_prompt_line(line) + logger.info(f"Parsed prompt data: {prompt_data}") + prompts_data.append(prompt_data) + + return prompts_data + + +def load_shared_models(args: argparse.Namespace) -> Dict: + """Load shared models for batch processing or interactive mode. + Models are loaded to CPU to save memory. VAE is NOT loaded here. + DiT model is also NOT loaded here, handled by process_batch_prompts or generate. + + Args: + args: Base command line arguments + + Returns: + Dict: Dictionary of shared models (text/image encoders) + """ + shared_models = {} + # Load text encoders to CPU + vl_dtype = torch.bfloat16 # Default dtype for Text Encoder + tokenizer_vlm, text_encoder_vlm = hunyuan_image_text_encoder.load_qwen2_5_vl( + args.text_encoder, dtype=vl_dtype, device="cpu", disable_mmap=True + ) + tokenizer_byt5, text_encoder_byt5 = hunyuan_image_text_encoder.load_byt5( + args.byt5, dtype=torch.float16, device="cpu", disable_mmap=True + ) + shared_models["tokenizer_vlm"] = tokenizer_vlm + shared_models["text_encoder_vlm"] = text_encoder_vlm + shared_models["tokenizer_byt5"] = tokenizer_byt5 + shared_models["text_encoder_byt5"] = text_encoder_byt5 + return shared_models + + +def process_batch_prompts(prompts_data: List[Dict], args: argparse.Namespace) -> None: + """Process multiple prompts with model reuse and batched precomputation + + Args: + prompts_data: List of prompt data dictionaries + args: Base command line arguments + """ + if not prompts_data: + logger.warning("No valid prompts found") + return + + gen_settings = get_generation_settings(args) + dit_weight_dtype = gen_settings.dit_weight_dtype + device = gen_settings.device + + # 1. Prepare VAE + logger.info("Loading VAE for batch generation...") + vae_for_batch = hunyuan_image_vae.load_vae(args.vae, device="cpu", disable_mmap=True, chunk_size=args.vae_chunk_size) + vae_for_batch.eval() + + all_prompt_args_list = [apply_overrides(args, pd) for pd in prompts_data] # Create all arg instances first + for prompt_args in all_prompt_args_list: + check_inputs(prompt_args) # Validate each prompt's height/width + + # 2. Precompute Text Data (Text Encoder) + logger.info("Loading Text Encoder for batch text preprocessing...") + + # Text Encoder loaded to CPU by load_text_encoder + vl_dtype = torch.bfloat16 # Default dtype for Text Encoder + tokenizer_vlm, text_encoder_vlm_batch = hunyuan_image_text_encoder.load_qwen2_5_vl( + args.text_encoder, dtype=vl_dtype, device="cpu", disable_mmap=True + ) + tokenizer_byt5, text_encoder_byt5_batch = hunyuan_image_text_encoder.load_byt5( + args.byt5, dtype=torch.float16, device="cpu", disable_mmap=True + ) + + # Text Encoder to device for this phase + vl_device = torch.device("cpu") if args.text_encoder_cpu else device + text_encoder_vlm_batch.to(vl_device) # Moved into prepare_text_inputs logic + text_encoder_byt5_batch.to(vl_device) + + all_precomputed_text_data = [] + conds_cache_batch = {} + + logger.info("Preprocessing text and LLM/TextEncoder encoding for all prompts...") + temp_shared_models_txt = { + "tokenizer_vlm": tokenizer_vlm, + "text_encoder_vlm": text_encoder_vlm_batch, # on GPU if not text_encoder_cpu + "tokenizer_byt5": tokenizer_byt5, + "text_encoder_byt5": text_encoder_byt5_batch, # on GPU if not text_encoder_cpu + "conds_cache": conds_cache_batch, + } + + for i, prompt_args_item in enumerate(all_prompt_args_list): + logger.info(f"Text preprocessing for prompt {i+1}/{len(all_prompt_args_list)}: {prompt_args_item.prompt}") + + # prepare_text_inputs will move text_encoders to device temporarily + context, context_null = prepare_text_inputs(prompt_args_item, device, temp_shared_models_txt) + text_data = {"context": context, "context_null": context_null} + all_precomputed_text_data.append(text_data) + + # Models should be removed from device after prepare_text_inputs + del tokenizer_vlm, text_encoder_vlm_batch, tokenizer_byt5, text_encoder_byt5_batch, temp_shared_models_txt, conds_cache_batch + gc.collect() # Force cleanup of Text Encoder from GPU memory + clean_memory_on_device(device) + + # 3. Load DiT Model once + logger.info("Loading DiT model for batch generation...") + # Use args from the first prompt for DiT loading (LoRA etc. should be consistent for a batch) + first_prompt_args = all_prompt_args_list[0] + dit_model = load_dit_model(first_prompt_args, device, dit_weight_dtype) # Load directly to target device if possible + + if first_prompt_args.save_merged_model: + logger.info("Merged DiT model saved. Skipping generation.") + + shared_models_for_generate = {"model": dit_model} # Pass DiT via shared_models + + all_latents = [] + + logger.info("Generating latents for all prompts...") + with torch.no_grad(): + for i, prompt_args_item in enumerate(all_prompt_args_list): + current_text_data = all_precomputed_text_data[i] + height, width = check_inputs(prompt_args_item) # Get height/width for each prompt + + logger.info(f"Generating latent for prompt {i+1}/{len(all_prompt_args_list)}: {prompt_args_item.prompt}") + try: + # generate is called with precomputed data, so it won't load Text Encoders. + # It will use the DiT model from shared_models_for_generate. + latent = generate(prompt_args_item, gen_settings, shared_models_for_generate, current_text_data) + + if latent is None: # and prompt_args_item.save_merged_model: # Should be caught earlier + continue + + # Save latent if needed (using data from precomputed_image_data for H/W) + if prompt_args_item.output_type in ["latent", "latent_images"]: + save_latent(latent, prompt_args_item, height, width) + + all_latents.append(latent) + except Exception as e: + logger.error(f"Error generating latent for prompt: {prompt_args_item.prompt}. Error: {e}", exc_info=True) + all_latents.append(None) # Add placeholder for failed generations + continue + + # Free DiT model + logger.info("Releasing DiT model from memory...") + if args.blocks_to_swap > 0: + logger.info("Waiting for 5 seconds to finish block swap") + time.sleep(5) + + del shared_models_for_generate["model"] + del dit_model + clean_memory_on_device(device) + synchronize_device(device) # Ensure memory is freed before loading VAE for decoding + + # 4. Decode latents and save outputs (using vae_for_batch) + if args.output_type != "latent": + logger.info("Decoding latents to videos/images using batched VAE...") + vae_for_batch.to(device) # Move VAE to device for decoding + + for i, latent in enumerate(all_latents): + if latent is None: # Skip failed generations + logger.warning(f"Skipping decoding for prompt {i+1} due to previous error.") + continue + + current_args = all_prompt_args_list[i] + logger.info(f"Decoding output {i+1}/{len(all_latents)} for prompt: {current_args.prompt}") + + # if args.output_type is "latent_images", we already saved latent above. + # so we skip saving latent here. + if current_args.output_type == "latent_images": + current_args.output_type = "images" + + # save_output expects latent to be [BCTHW] or [CTHW]. generate returns [BCTHW] (batch size 1). + # latent[0] is correct if generate returns it with batch dim. + # The latent from generate is (1, C, T, H, W) + save_output(current_args, vae_for_batch, latent, device) # Pass vae_for_batch + + vae_for_batch.to("cpu") # Move VAE back to CPU + + del vae_for_batch + clean_memory_on_device(device) + + +def process_interactive(args: argparse.Namespace) -> None: + """Process prompts in interactive mode + + Args: + args: Base command line arguments + """ + gen_settings = get_generation_settings(args) + device = gen_settings.device + shared_models = load_shared_models(args) + shared_models["conds_cache"] = {} # Initialize empty cache for interactive mode + + vae = hunyuan_image_vae.load_vae(args.vae, device="cpu", disable_mmap=True, chunk_size=args.vae_chunk_size) + vae.eval() + + print("Interactive mode. Enter prompts (Ctrl+D or Ctrl+Z (Windows) to exit):") + + try: + import prompt_toolkit + except ImportError: + logger.warning("prompt_toolkit not found. Using basic input instead.") + prompt_toolkit = None + + if prompt_toolkit: + session = prompt_toolkit.PromptSession() + + def input_line(prompt: str) -> str: + return session.prompt(prompt) + + else: + + def input_line(prompt: str) -> str: + return input(prompt) + + try: + while True: + try: + line = input_line("> ") + if not line.strip(): + continue + if len(line.strip()) == 1 and line.strip() in ["\x04", "\x1a"]: # Ctrl+D or Ctrl+Z with prompt_toolkit + raise EOFError # Exit on Ctrl+D or Ctrl+Z + + # Parse prompt + prompt_data = parse_prompt_line(line) + prompt_args = apply_overrides(args, prompt_data) + + # Generate latent + # For interactive, precomputed data is None. shared_models contains text encoders. + latent = generate(prompt_args, gen_settings, shared_models) + + # # If not one_frame_inference, move DiT model to CPU after generation + # if prompt_args.blocks_to_swap > 0: + # logger.info("Waiting for 5 seconds to finish block swap") + # time.sleep(5) + # model = shared_models.get("model") + # model.to("cpu") # Move DiT model to CPU after generation + + # Save latent and video + # returned_vae from generate will be used for decoding here. + save_output(prompt_args, vae, latent, device) + + except KeyboardInterrupt: + print("\nInterrupted. Continue (Ctrl+D or Ctrl+Z (Windows) to exit)") + continue + + except EOFError: + print("\nExiting interactive mode") + + +def get_generation_settings(args: argparse.Namespace) -> GenerationSettings: + device = torch.device(args.device) + + dit_weight_dtype = torch.bfloat16 # default + if args.fp8_scaled: + dit_weight_dtype = None # various precision weights, so don't cast to specific dtype + elif args.fp8: + dit_weight_dtype = torch.float8_e4m3fn + + logger.info(f"Using device: {device}, DiT weight weight precision: {dit_weight_dtype}") + + gen_settings = GenerationSettings(device=device, dit_weight_dtype=dit_weight_dtype) + return gen_settings + + +def main(): + # Parse arguments + args = parse_args() + + # Check if latents are provided + latents_mode = args.latent_path is not None and len(args.latent_path) > 0 + + # Set device + device = args.device if args.device is not None else "cuda" if torch.cuda.is_available() else "cpu" + device = torch.device(device) + logger.info(f"Using device: {device}") + args.device = device + + if latents_mode: + # Original latent decode mode + original_base_names = [] + latents_list = [] + seeds = [] + + # assert len(args.latent_path) == 1, "Only one latent path is supported for now" + + for latent_path in args.latent_path: + original_base_names.append(os.path.splitext(os.path.basename(latent_path))[0]) + seed = 0 + + if os.path.splitext(latent_path)[1] != ".safetensors": + latents = torch.load(latent_path, map_location="cpu") + else: + latents = load_file(latent_path)["latent"] + with safe_open(latent_path, framework="pt") as f: + metadata = f.metadata() + if metadata is None: + metadata = {} + logger.info(f"Loaded metadata: {metadata}") + + if "seeds" in metadata: + seed = int(metadata["seeds"]) + if "height" in metadata and "width" in metadata: + height = int(metadata["height"]) + width = int(metadata["width"]) + args.image_size = [height, width] + + seeds.append(seed) + logger.info(f"Loaded latent from {latent_path}. Shape: {latents.shape}") + + if latents.ndim == 5: # [BCTHW] + latents = latents.squeeze(0) # [CTHW] + + latents_list.append(latents) + + # latent = torch.stack(latents_list, dim=0) # [N, ...], must be same shape + + for i, latent in enumerate(latents_list): + args.seed = seeds[i] + + vae = hunyuan_image_vae.load_vae(args.vae, device=device, disable_mmap=True, chunk_size=args.vae_chunk_size) + vae.eval() + save_output(args, vae, latent, device, original_base_names[i]) + + elif args.from_file: + # Batch mode from file + + # Read prompts from file + with open(args.from_file, "r", encoding="utf-8") as f: + prompt_lines = f.readlines() + + # Process prompts + prompts_data = preprocess_prompts_for_batch(prompt_lines, args) + process_batch_prompts(prompts_data, args) + + elif args.interactive: + # Interactive mode + process_interactive(args) + + else: + # Single prompt mode (original behavior) + + # Generate latent + gen_settings = get_generation_settings(args) + + # For single mode, precomputed data is None, shared_models is None. + # generate will load all necessary models (Text Encoders, DiT). + latent = generate(args, gen_settings) + # print(f"Generated latent shape: {latent.shape}") + # if args.save_merged_model: + # return + + clean_memory_on_device(device) + + # Save latent and video + vae = hunyuan_image_vae.load_vae(args.vae, device="cpu", disable_mmap=True, chunk_size=args.vae_chunk_size) + vae.eval() + save_output(args, vae, latent, device) + + logger.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/hunyuan_image_train_network.py b/hunyuan_image_train_network.py new file mode 100644 index 000000000..9ab351ea2 --- /dev/null +++ b/hunyuan_image_train_network.py @@ -0,0 +1,717 @@ +import argparse +import copy +import gc +from typing import Any, Optional, Union, cast +import os +import time +from types import SimpleNamespace + +import numpy as np +import torch +import torch.nn as nn +from PIL import Image +from accelerate import Accelerator, PartialState + +from library import flux_utils, hunyuan_image_models, hunyuan_image_vae, strategy_base, train_util +from library.device_utils import clean_memory_on_device, init_ipex + +init_ipex() + +import train_network +from library import ( + flux_train_utils, + hunyuan_image_models, + hunyuan_image_text_encoder, + hunyuan_image_utils, + hunyuan_image_vae, + sd3_train_utils, + strategy_base, + strategy_hunyuan_image, + train_util, +) +from library.utils import setup_logging + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + + +# region sampling + + +# TODO commonize with flux_utils +def sample_images( + accelerator: Accelerator, + args: argparse.Namespace, + epoch, + steps, + dit: hunyuan_image_models.HYImageDiffusionTransformer, + vae, + text_encoders, + sample_prompts_te_outputs, + prompt_replacement=None, +): + if steps == 0: + if not args.sample_at_first: + return + else: + if args.sample_every_n_steps is None and args.sample_every_n_epochs is None: + return + if args.sample_every_n_epochs is not None: + # sample_every_n_steps は無視する + if epoch is None or epoch % args.sample_every_n_epochs != 0: + return + else: + if steps % args.sample_every_n_steps != 0 or epoch is not None: # steps is not divisible or end of epoch + return + + logger.info("") + logger.info(f"generating sample images at step / サンプル画像生成 ステップ: {steps}") + if not os.path.isfile(args.sample_prompts) and sample_prompts_te_outputs is None: + logger.error(f"No prompt file / プロンプトファイルがありません: {args.sample_prompts}") + return + + distributed_state = PartialState() # for multi gpu distributed inference. this is a singleton, so it's safe to use it here + + # unwrap unet and text_encoder(s) + dit = accelerator.unwrap_model(dit) + dit = cast(hunyuan_image_models.HYImageDiffusionTransformer, dit) + dit.switch_block_swap_for_inference() + if text_encoders is not None: + text_encoders = [(accelerator.unwrap_model(te) if te is not None else None) for te in text_encoders] + # print([(te.parameters().__next__().device if te is not None else None) for te in text_encoders]) + + prompts = train_util.load_prompts(args.sample_prompts) + + save_dir = args.output_dir + "/sample" + os.makedirs(save_dir, exist_ok=True) + + # save random state to restore later + rng_state = torch.get_rng_state() + cuda_rng_state = None + try: + cuda_rng_state = torch.cuda.get_rng_state() if torch.cuda.is_available() else None + except Exception: + pass + + if distributed_state.num_processes <= 1: + # If only one device is available, just use the original prompt list. We don't need to care about the distribution of prompts. + with torch.no_grad(), accelerator.autocast(): + for prompt_dict in prompts: + sample_image_inference( + accelerator, + args, + dit, + text_encoders, + vae, + save_dir, + prompt_dict, + epoch, + steps, + sample_prompts_te_outputs, + prompt_replacement, + ) + else: + # Creating list with N elements, where each element is a list of prompt_dicts, and N is the number of processes available (number of devices available) + # prompt_dicts are assigned to lists based on order of processes, to attempt to time the image creation time to match enum order. Probably only works when steps and sampler are identical. + per_process_prompts = [] # list of lists + for i in range(distributed_state.num_processes): + per_process_prompts.append(prompts[i :: distributed_state.num_processes]) + + with torch.no_grad(): + with distributed_state.split_between_processes(per_process_prompts) as prompt_dict_lists: + for prompt_dict in prompt_dict_lists[0]: + sample_image_inference( + accelerator, + args, + dit, + text_encoders, + vae, + save_dir, + prompt_dict, + epoch, + steps, + sample_prompts_te_outputs, + prompt_replacement, + ) + + torch.set_rng_state(rng_state) + if cuda_rng_state is not None: + torch.cuda.set_rng_state(cuda_rng_state) + + dit.switch_block_swap_for_training() + clean_memory_on_device(accelerator.device) + + +def sample_image_inference( + accelerator: Accelerator, + args: argparse.Namespace, + dit: hunyuan_image_models.HYImageDiffusionTransformer, + text_encoders: Optional[list[nn.Module]], + vae: hunyuan_image_vae.HunyuanVAE2D, + save_dir, + prompt_dict, + epoch, + steps, + sample_prompts_te_outputs, + prompt_replacement, +): + assert isinstance(prompt_dict, dict) + negative_prompt = prompt_dict.get("negative_prompt") + sample_steps = prompt_dict.get("sample_steps", 20) + width = prompt_dict.get("width", 512) + height = prompt_dict.get("height", 512) + cfg_scale = prompt_dict.get("scale", 3.5) + seed = prompt_dict.get("seed") + prompt: str = prompt_dict.get("prompt", "") + flow_shift: float = prompt_dict.get("flow_shift", 5.0) + # sampler_name: str = prompt_dict.get("sample_sampler", args.sample_sampler) + + if prompt_replacement is not None: + prompt = prompt.replace(prompt_replacement[0], prompt_replacement[1]) + if negative_prompt is not None: + negative_prompt = negative_prompt.replace(prompt_replacement[0], prompt_replacement[1]) + + if seed is not None: + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + else: + # True random sample image generation + torch.seed() + torch.cuda.seed() + + if negative_prompt is None: + negative_prompt = "" + height = max(64, height - height % 16) # round to divisible by 16 + width = max(64, width - width % 16) # round to divisible by 16 + logger.info(f"prompt: {prompt}") + if cfg_scale != 1.0: + logger.info(f"negative_prompt: {negative_prompt}") + elif negative_prompt != "": + logger.info(f"negative prompt is ignored because scale is 1.0") + logger.info(f"height: {height}") + logger.info(f"width: {width}") + logger.info(f"sample_steps: {sample_steps}") + if cfg_scale != 1.0: + logger.info(f"CFG scale: {cfg_scale}") + logger.info(f"flow_shift: {flow_shift}") + # logger.info(f"sample_sampler: {sampler_name}") + if seed is not None: + logger.info(f"seed: {seed}") + + # encode prompts + tokenize_strategy = strategy_base.TokenizeStrategy.get_strategy() + encoding_strategy = strategy_base.TextEncodingStrategy.get_strategy() + + def encode_prompt(prpt): + text_encoder_conds = [] + if sample_prompts_te_outputs and prpt in sample_prompts_te_outputs: + text_encoder_conds = sample_prompts_te_outputs[prpt] + # print(f"Using cached text encoder outputs for prompt: {prpt}") + if text_encoders is not None: + # print(f"Encoding prompt: {prpt}") + tokens_and_masks = tokenize_strategy.tokenize(prpt) + encoded_text_encoder_conds = encoding_strategy.encode_tokens(tokenize_strategy, text_encoders, tokens_and_masks) + + # if text_encoder_conds is not cached, use encoded_text_encoder_conds + if len(text_encoder_conds) == 0: + text_encoder_conds = encoded_text_encoder_conds + else: + # if encoded_text_encoder_conds is not None, update cached text_encoder_conds + for i in range(len(encoded_text_encoder_conds)): + if encoded_text_encoder_conds[i] is not None: + text_encoder_conds[i] = encoded_text_encoder_conds[i] + return text_encoder_conds + + vl_embed, vl_mask, byt5_embed, byt5_mask, ocr_mask = encode_prompt(prompt) + arg_c = { + "embed": vl_embed, + "mask": vl_mask, + "embed_byt5": byt5_embed, + "mask_byt5": byt5_mask, + "ocr_mask": ocr_mask, + "prompt": prompt, + } + + # encode negative prompts + if cfg_scale != 1.0: + neg_vl_embed, neg_vl_mask, neg_byt5_embed, neg_byt5_mask, neg_ocr_mask = encode_prompt(negative_prompt) + arg_c_null = { + "embed": neg_vl_embed, + "mask": neg_vl_mask, + "embed_byt5": neg_byt5_embed, + "mask_byt5": neg_byt5_mask, + "ocr_mask": neg_ocr_mask, + "prompt": negative_prompt, + } + else: + arg_c_null = None + + gen_args = SimpleNamespace( + image_size=(height, width), + infer_steps=sample_steps, + flow_shift=flow_shift, + guidance_scale=cfg_scale, + fp8=args.fp8_scaled, + apg_start_step_ocr=38, + apg_start_step_general=5, + guidance_rescale=0.0, + guidance_rescale_apg=0.0, + ) + + from hunyuan_image_minimal_inference import generate_body # import here to avoid circular import + + dit_is_training = dit.training + dit.eval() + x = generate_body(gen_args, dit, arg_c, arg_c_null, accelerator.device, seed) + if dit_is_training: + dit.train() + clean_memory_on_device(accelerator.device) + + # latent to image + org_vae_device = vae.device # will be on cpu + vae.to(accelerator.device) # distributed_state.device is same as accelerator.device + with torch.no_grad(): + x = x / vae.scaling_factor + x = vae.decode(x.to(vae.device, dtype=vae.dtype)) + vae.to(org_vae_device) + + clean_memory_on_device(accelerator.device) + + x = x.clamp(-1, 1) + x = x.permute(0, 2, 3, 1) + image = Image.fromarray((127.5 * (x + 1.0)).float().cpu().numpy().astype(np.uint8)[0]) + + # adding accelerator.wait_for_everyone() here should sync up and ensure that sample images are saved in the same order as the original prompt list + # but adding 'enum' to the filename should be enough + + ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime()) + num_suffix = f"e{epoch:06d}" if epoch is not None else f"{steps:06d}" + seed_suffix = "" if seed is None else f"_{seed}" + i: int = prompt_dict["enum"] + img_filename = f"{'' if args.output_name is None else args.output_name + '_'}{num_suffix}_{i:02d}_{ts_str}{seed_suffix}.png" + image.save(os.path.join(save_dir, img_filename)) + + # send images to wandb if enabled + if "wandb" in [tracker.name for tracker in accelerator.trackers]: + wandb_tracker = accelerator.get_tracker("wandb") + + import wandb + + # not to commit images to avoid inconsistency between training and logging steps + wandb_tracker.log({f"sample_{i}": wandb.Image(image, caption=prompt)}, commit=False) # positive prompt as a caption + + +# endregion + + +class HunyuanImageNetworkTrainer(train_network.NetworkTrainer): + def __init__(self): + super().__init__() + self.sample_prompts_te_outputs = None + self.is_swapping_blocks: bool = False + self.rotary_pos_emb_cache = {} + + def assert_extra_args( + self, + args, + train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset], + val_dataset_group: Optional[train_util.DatasetGroup], + ): + super().assert_extra_args(args, train_dataset_group, val_dataset_group) + # sdxl_train_util.verify_sdxl_training_args(args) + + if args.mixed_precision == "fp16": + logger.warning( + "mixed_precision bf16 is recommended for HunyuanImage-2.1 / HunyuanImage-2.1ではmixed_precision bf16が推奨されます" + ) + + if (args.fp8_base or args.fp8_base_unet) and not args.fp8_scaled: + logger.warning( + "fp8_base and fp8_base_unet are not supported. Use fp8_scaled instead / fp8_baseとfp8_base_unetはサポートされていません。代わりにfp8_scaledを使用してください" + ) + if args.fp8_scaled and (args.fp8_base or args.fp8_base_unet): + logger.info( + "fp8_scaled is used, so fp8_base and fp8_base_unet are ignored / fp8_scaledが使われているので、fp8_baseとfp8_base_unetは無視されます" + ) + args.fp8_base = False + args.fp8_base_unet = False + + if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs: + logger.warning( + "cache_text_encoder_outputs_to_disk is enabled, so cache_text_encoder_outputs is also enabled / cache_text_encoder_outputs_to_diskが有効になっているため、cache_text_encoder_outputsも有効になります" + ) + args.cache_text_encoder_outputs = True + + if args.cache_text_encoder_outputs: + assert ( + train_dataset_group.is_text_encoder_output_cacheable() + ), "when caching Text Encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / Text Encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません" + + train_dataset_group.verify_bucket_reso_steps(32) + if val_dataset_group is not None: + val_dataset_group.verify_bucket_reso_steps(32) + + def load_target_model(self, args, weight_dtype, accelerator): + self.is_swapping_blocks = args.blocks_to_swap is not None and args.blocks_to_swap > 0 + + vl_dtype = torch.float8_e4m3fn if args.fp8_vl else torch.bfloat16 + vl_device = "cpu" # loading to cpu and move to gpu later in cache_text_encoder_outputs_if_needed + _, text_encoder_vlm = hunyuan_image_text_encoder.load_qwen2_5_vl( + args.text_encoder, dtype=vl_dtype, device=vl_device, disable_mmap=args.disable_mmap_load_safetensors + ) + _, text_encoder_byt5 = hunyuan_image_text_encoder.load_byt5( + args.byt5, dtype=torch.float16, device=vl_device, disable_mmap=args.disable_mmap_load_safetensors + ) + + vae = hunyuan_image_vae.load_vae( + args.vae, "cpu", disable_mmap=args.disable_mmap_load_safetensors, chunk_size=args.vae_chunk_size + ) + vae.to(dtype=torch.float16) # VAE is always fp16 + vae.eval() + + model_version = hunyuan_image_utils.MODEL_VERSION_2_1 + return model_version, [text_encoder_vlm, text_encoder_byt5], vae, None # unet will be loaded later + + def load_unet_lazily(self, args, weight_dtype, accelerator, text_encoders) -> tuple[nn.Module, list[nn.Module]]: + if args.cache_text_encoder_outputs: + logger.info("Replace text encoders with dummy models to save memory") + + # This doesn't free memory, so we move text encoders to meta device in cache_text_encoder_outputs_if_needed + text_encoders = [flux_utils.dummy_clip_l() for _ in text_encoders] + clean_memory_on_device(accelerator.device) + gc.collect() + + loading_dtype = None if args.fp8_scaled else weight_dtype + loading_device = "cpu" if self.is_swapping_blocks else accelerator.device + + attn_mode = "torch" + if args.xformers: + attn_mode = "xformers" + if args.attn_mode is not None: + attn_mode = args.attn_mode + + logger.info(f"Loading DiT model with attn_mode: {attn_mode}, split_attn: {args.split_attn}, fp8_scaled: {args.fp8_scaled}") + model = hunyuan_image_models.load_hunyuan_image_model( + accelerator.device, + args.pretrained_model_name_or_path, + attn_mode, + args.split_attn, + loading_device, + loading_dtype, + args.fp8_scaled, + ) + + if self.is_swapping_blocks: + # Swap blocks between CPU and GPU to reduce memory usage, in forward and backward passes. + logger.info(f"enable block swap: blocks_to_swap={args.blocks_to_swap}") + model.enable_block_swap(args.blocks_to_swap, accelerator.device, supports_backward=True) + + return model, text_encoders + + def get_tokenize_strategy(self, args): + return strategy_hunyuan_image.HunyuanImageTokenizeStrategy(args.tokenizer_cache_dir) + + def get_tokenizers(self, tokenize_strategy: strategy_hunyuan_image.HunyuanImageTokenizeStrategy): + return [tokenize_strategy.vlm_tokenizer, tokenize_strategy.byt5_tokenizer] + + def get_latents_caching_strategy(self, args): + return strategy_hunyuan_image.HunyuanImageLatentsCachingStrategy(args.cache_latents_to_disk, args.vae_batch_size, False) + + def get_text_encoding_strategy(self, args): + return strategy_hunyuan_image.HunyuanImageTextEncodingStrategy() + + def post_process_network(self, args, accelerator, network, text_encoders, unet): + pass + + def get_models_for_text_encoding(self, args, accelerator, text_encoders): + if args.cache_text_encoder_outputs: + return None # no text encoders are needed for encoding because both are cached + else: + return text_encoders + + def get_text_encoders_train_flags(self, args, text_encoders): + # HunyuanImage-2.1 does not support training VLM or byT5 + return [False, False] + + def get_text_encoder_outputs_caching_strategy(self, args): + if args.cache_text_encoder_outputs: + return strategy_hunyuan_image.HunyuanImageTextEncoderOutputsCachingStrategy( + args.cache_text_encoder_outputs_to_disk, args.text_encoder_batch_size, args.skip_cache_check, False + ) + else: + return None + + def cache_text_encoder_outputs_if_needed( + self, args, accelerator: Accelerator, unet, vae, text_encoders, dataset: train_util.DatasetGroup, weight_dtype + ): + vlm_device = "cpu" if args.text_encoder_cpu else accelerator.device + if args.cache_text_encoder_outputs: + if not args.lowram: + # メモリ消費を減らす + logger.info("move vae to cpu to save memory") + org_vae_device = vae.device + vae.to("cpu") + clean_memory_on_device(accelerator.device) + + logger.info(f"move text encoders to {vlm_device} to encode and cache text encoder outputs") + text_encoders[0].to(vlm_device) + text_encoders[1].to(vlm_device) + + # VLM (bf16) and byT5 (fp16) are used for encoding, so we cannot use autocast here + dataset.new_cache_text_encoder_outputs(text_encoders, accelerator) + + # cache sample prompts + if args.sample_prompts is not None: + logger.info(f"cache Text Encoder outputs for sample prompt: {args.sample_prompts}") + + tokenize_strategy: strategy_hunyuan_image.HunyuanImageTokenizeStrategy = ( + strategy_base.TokenizeStrategy.get_strategy() + ) + text_encoding_strategy: strategy_hunyuan_image.HunyuanImageTextEncodingStrategy = ( + strategy_base.TextEncodingStrategy.get_strategy() + ) + + prompts = train_util.load_prompts(args.sample_prompts) + sample_prompts_te_outputs = {} # key: prompt, value: text encoder outputs + with accelerator.autocast(), torch.no_grad(): + for prompt_dict in prompts: + for p in [prompt_dict.get("prompt", ""), prompt_dict.get("negative_prompt", "")]: + if p not in sample_prompts_te_outputs: + logger.info(f"cache Text Encoder outputs for prompt: {p}") + tokens_and_masks = tokenize_strategy.tokenize(p) + sample_prompts_te_outputs[p] = text_encoding_strategy.encode_tokens( + tokenize_strategy, text_encoders, tokens_and_masks + ) + self.sample_prompts_te_outputs = sample_prompts_te_outputs + + accelerator.wait_for_everyone() + + # text encoders are not needed for training, so we move to meta device + logger.info("move text encoders to meta device to save memory") + text_encoders = [te.to("meta") for te in text_encoders] + clean_memory_on_device(accelerator.device) + + if not args.lowram: + logger.info("move vae back to original device") + vae.to(org_vae_device) + else: + # Text Encoderから毎回出力を取得するので、GPUに乗せておく + text_encoders[0].to(vlm_device) + text_encoders[1].to(vlm_device) + + def sample_images(self, accelerator, args, epoch, global_step, device, ae, tokenizer, text_encoder, flux): + text_encoders = text_encoder # for compatibility + text_encoders = self.get_models_for_text_encoding(args, accelerator, text_encoders) + + sample_images(accelerator, args, epoch, global_step, flux, ae, text_encoders, self.sample_prompts_te_outputs) + + def get_noise_scheduler(self, args: argparse.Namespace, device: torch.device) -> Any: + noise_scheduler = sd3_train_utils.FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=args.discrete_flow_shift) + self.noise_scheduler_copy = copy.deepcopy(noise_scheduler) + return noise_scheduler + + def encode_images_to_latents(self, args, vae: hunyuan_image_vae.HunyuanVAE2D, images): + return vae.encode(images).sample() + + def shift_scale_latents(self, args, latents): + # for encoding, we need to scale the latents + return latents * hunyuan_image_vae.LATENT_SCALING_FACTOR + + def get_noise_pred_and_target( + self, + args, + accelerator, + noise_scheduler, + latents, + batch, + text_encoder_conds, + unet: hunyuan_image_models.HYImageDiffusionTransformer, + network, + weight_dtype, + train_unet, + is_train=True, + ): + # Sample noise that we'll add to the latents + noise = torch.randn_like(latents) + + # get noisy model input and timesteps + noisy_model_input, _, sigmas = flux_train_utils.get_noisy_model_input_and_timesteps( + args, noise_scheduler, latents, noise, accelerator.device, weight_dtype + ) + # bfloat16 is too low precision for 0-1000 TODO fix get_noisy_model_input_and_timesteps + timesteps = (sigmas[:, 0, 0, 0] * 1000).to(torch.int64) + # print( + # f"timestep: {timesteps}, noisy_model_input shape: {noisy_model_input.shape}, mean: {noisy_model_input.mean()}, std: {noisy_model_input.std()}" + # ) + + if args.gradient_checkpointing: + noisy_model_input.requires_grad_(True) + for t in text_encoder_conds: + if t is not None and t.dtype.is_floating_point: + t.requires_grad_(True) + + # Predict the noise residual + # ocr_mask is for inference only, so it is not used here + vlm_embed, vlm_mask, byt5_embed, byt5_mask, ocr_mask = text_encoder_conds + + # print(f"embed shape: {vlm_embed.shape}, mean: {vlm_embed.mean()}, std: {vlm_embed.std()}") + # print(f"embed_byt5 shape: {byt5_embed.shape}, mean: {byt5_embed.mean()}, std: {byt5_embed.std()}") + # print(f"latents shape: {latents.shape}, mean: {latents.mean()}, std: {latents.std()}") + # print(f"mask shape: {vlm_mask.shape}, sum: {vlm_mask.sum()}") + # print(f"mask_byt5 shape: {byt5_mask.shape}, sum: {byt5_mask.sum()}") + with torch.set_grad_enabled(is_train), accelerator.autocast(): + model_pred = unet( + noisy_model_input, timesteps, vlm_embed, vlm_mask, byt5_embed, byt5_mask # , self.rotary_pos_emb_cache + ) + + # apply model prediction type + model_pred, weighting = flux_train_utils.apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas) + + # flow matching loss + target = noise - latents + + # differential output preservation is not used for HunyuanImage-2.1 currently + + return model_pred, target, timesteps, weighting + + def post_process_loss(self, loss, args, timesteps, noise_scheduler): + return loss + + def get_sai_model_spec(self, args): + return train_util.get_sai_model_spec_dataclass(None, args, False, True, False, hunyuan_image="2.1").to_metadata_dict() + + def update_metadata(self, metadata, args): + metadata["ss_logit_mean"] = args.logit_mean + metadata["ss_logit_std"] = args.logit_std + metadata["ss_mode_scale"] = args.mode_scale + metadata["ss_timestep_sampling"] = args.timestep_sampling + metadata["ss_sigmoid_scale"] = args.sigmoid_scale + metadata["ss_model_prediction_type"] = args.model_prediction_type + metadata["ss_discrete_flow_shift"] = args.discrete_flow_shift + + def is_text_encoder_not_needed_for_training(self, args): + return args.cache_text_encoder_outputs and not self.is_train_text_encoder(args) + + def prepare_text_encoder_grad_ckpt_workaround(self, index, text_encoder): + # do not support text encoder training for HunyuanImage-2.1 + pass + + def cast_text_encoder(self, args): + return False # VLM is bf16, byT5 is fp16, so do not cast to other dtype + + def cast_vae(self, args): + return False # VAE is fp16, so do not cast to other dtype + + def cast_unet(self, args): + return not args.fp8_scaled # if fp8_scaled is used, do not cast to other dtype + + def prepare_text_encoder_fp8(self, index, text_encoder, te_weight_dtype, weight_dtype): + # fp8 text encoder for HunyuanImage-2.1 is not supported currently + pass + + def on_validation_step_end(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype): + if self.is_swapping_blocks: + # prepare for next forward: because backward pass is not called, we need to prepare it here + accelerator.unwrap_model(unet).prepare_block_swap_before_forward() + + def prepare_unet_with_accelerator( + self, args: argparse.Namespace, accelerator: Accelerator, unet: torch.nn.Module + ) -> torch.nn.Module: + if not self.is_swapping_blocks: + return super().prepare_unet_with_accelerator(args, accelerator, unet) + + # if we doesn't swap blocks, we can move the model to device + model: hunyuan_image_models.HYImageDiffusionTransformer = unet + model = accelerator.prepare(model, device_placement=[not self.is_swapping_blocks]) + accelerator.unwrap_model(model).move_to_device_except_swap_blocks(accelerator.device) # reduce peak memory usage + accelerator.unwrap_model(model).prepare_block_swap_before_forward() + + return model + + +def setup_parser() -> argparse.ArgumentParser: + parser = train_network.setup_parser() + train_util.add_dit_training_arguments(parser) + + parser.add_argument( + "--text_encoder", + type=str, + help="path to Qwen2.5-VL (*.sft or *.safetensors), should be bfloat16 / Qwen2.5-VLのパス(*.sftまたは*.safetensors)、bfloat16が前提", + ) + parser.add_argument( + "--byt5", + type=str, + help="path to byt5 (*.sft or *.safetensors), should be float16 / byt5のパス(*.sftまたは*.safetensors)、float16が前提", + ) + + parser.add_argument( + "--timestep_sampling", + choices=["sigma", "uniform", "sigmoid", "shift", "flux_shift"], + default="sigma", + help="Method to sample timesteps: sigma-based, uniform random, sigmoid of random normal, shift of sigmoid and FLUX.1 shifting." + " / タイムステップをサンプリングする方法:sigma、random uniform、random normalのsigmoid、sigmoidのシフト、FLUX.1のシフト。", + ) + parser.add_argument( + "--sigmoid_scale", + type=float, + default=1.0, + help='Scale factor for sigmoid timestep sampling (only used when timestep-sampling is "sigmoid"). / sigmoidタイムステップサンプリングの倍率(timestep-samplingが"sigmoid"の場合のみ有効)。', + ) + parser.add_argument( + "--model_prediction_type", + choices=["raw", "additive", "sigma_scaled"], + default="raw", + help="How to interpret and process the model prediction: " + "raw (use as is), additive (add to noisy input), sigma_scaled (apply sigma scaling). Default is raw unlike FLUX.1." + " / モデル予測の解釈と処理方法:" + "raw(そのまま使用)、additive(ノイズ入力に加算)、sigma_scaled(シグマスケーリングを適用)。デフォルトはFLUX.1とは異なりrawです。", + ) + parser.add_argument( + "--discrete_flow_shift", + type=float, + default=5.0, + help="Discrete flow shift for the Euler Discrete Scheduler, default is 5.0. / Euler Discrete Schedulerの離散フローシフト、デフォルトは5.0。", + ) + parser.add_argument("--fp8_scaled", action="store_true", help="Use scaled fp8 for DiT / DiTにスケーリングされたfp8を使う") + parser.add_argument("--fp8_vl", action="store_true", help="Use fp8 for VLM text encoder / VLMテキストエンコーダにfp8を使用する") + parser.add_argument( + "--text_encoder_cpu", action="store_true", help="Inference on CPU for Text Encoders / テキストエンコーダをCPUで推論する" + ) + parser.add_argument( + "--vae_chunk_size", + type=int, + default=None, # default is None (no chunking) + help="Chunk size for VAE decoding to reduce memory usage. Default is None (no chunking). 16 is recommended if enabled" + " / メモリ使用量を減らすためのVAEデコードのチャンクサイズ。デフォルトはNone(チャンクなし)。有効にする場合は16程度を推奨。", + ) + + parser.add_argument( + "--attn_mode", + choices=["torch", "xformers", "flash", "sageattn", "sdpa"], # "sdpa" is for backward compatibility + default=None, + help="Attention implementation to use. Default is None (torch). xformers requires --split_attn. sageattn does not support training (inference only). This option overrides --xformers or --sdpa." + " / 使用するAttentionの実装。デフォルトはNone(torch)です。xformersは--split_attnの指定が必要です。sageattnはトレーニングをサポートしていません(推論のみ)。このオプションは--xformersまたは--sdpaを上書きします。", + ) + parser.add_argument( + "--split_attn", + action="store_true", + help="split attention computation to reduce memory usage / メモリ使用量を減らすためにattention時にバッチを分割する", + ) + + return parser + + +if __name__ == "__main__": + parser = setup_parser() + + args = parser.parse_args() + train_util.verify_command_line_training_args(args) + args = train_util.read_config_from_file(args, parser) + + if args.attn_mode == "sdpa": + args.attn_mode = "torch" # backward compatibility + + trainer = HunyuanImageNetworkTrainer() + trainer.train(args) diff --git a/library/adafactor_fused.py b/library/adafactor_fused.py index bdfc32ced..b5afa236b 100644 --- a/library/adafactor_fused.py +++ b/library/adafactor_fused.py @@ -2,6 +2,32 @@ import torch from transformers import Adafactor +# stochastic rounding for bfloat16 +# The implementation was provided by 2kpr. Thank you very much! + +def copy_stochastic_(target: torch.Tensor, source: torch.Tensor): + """ + copies source into target using stochastic rounding + + Args: + target: the target tensor with dtype=bfloat16 + source: the target tensor with dtype=float32 + """ + # create a random 16 bit integer + result = torch.randint_like(source, dtype=torch.int32, low=0, high=(1 << 16)) + + # add the random number to the lower 16 bit of the mantissa + result.add_(source.view(dtype=torch.int32)) + + # mask off the lower 16 bit of the mantissa + result.bitwise_and_(-65536) # -65536 = FFFF0000 as a signed int32 + + # copy the higher 16 bit into the target tensor + target.copy_(result.view(dtype=torch.float32)) + + del result + + @torch.no_grad() def adafactor_step_param(self, p, group): if p.grad is None: @@ -48,7 +74,7 @@ def adafactor_step_param(self, p, group): lr = Adafactor._get_lr(group, state) beta2t = 1.0 - math.pow(state["step"], group["decay_rate"]) - update = (grad ** 2) + group["eps"][0] + update = (grad**2) + group["eps"][0] if factored: exp_avg_sq_row = state["exp_avg_sq_row"] exp_avg_sq_col = state["exp_avg_sq_col"] @@ -78,7 +104,12 @@ def adafactor_step_param(self, p, group): p_data_fp32.add_(-update) - if p.dtype in {torch.float16, torch.bfloat16}: + # if p.dtype in {torch.float16, torch.bfloat16}: + # p.copy_(p_data_fp32) + + if p.dtype == torch.bfloat16: + copy_stochastic_(p, p_data_fp32) + elif p.dtype == torch.float16: p.copy_(p_data_fp32) @@ -101,6 +132,7 @@ def adafactor_step(self, closure=None): return loss + def patch_adafactor_fused(optimizer: Adafactor): optimizer.step_param = adafactor_step_param.__get__(optimizer) optimizer.step = adafactor_step.__get__(optimizer) diff --git a/library/attention.py b/library/attention.py new file mode 100644 index 000000000..d3b8441e2 --- /dev/null +++ b/library/attention.py @@ -0,0 +1,260 @@ +# Unified attention function supporting various implementations + +from dataclasses import dataclass +import torch +from typing import Optional, Union + +try: + import flash_attn + from flash_attn.flash_attn_interface import _flash_attn_forward + from flash_attn.flash_attn_interface import flash_attn_varlen_func + from flash_attn.flash_attn_interface import flash_attn_func +except ImportError: + flash_attn = None + flash_attn_varlen_func = None + _flash_attn_forward = None + flash_attn_func = None + +try: + from sageattention import sageattn_varlen, sageattn +except ImportError: + sageattn_varlen = None + sageattn = None + +try: + import xformers.ops as xops +except ImportError: + xops = None + + +@dataclass +class AttentionParams: + attn_mode: Optional[str] = None + split_attn: bool = False + img_len: Optional[int] = None + attention_mask: Optional[torch.Tensor] = None + seqlens: Optional[torch.Tensor] = None + cu_seqlens: Optional[torch.Tensor] = None + max_seqlen: Optional[int] = None + + @staticmethod + def create_attention_params(attn_mode: Optional[str], split_attn: bool) -> "AttentionParams": + return AttentionParams(attn_mode, split_attn) + + @staticmethod + def create_attention_params_from_mask( + attn_mode: Optional[str], split_attn: bool, img_len: Optional[int], attention_mask: Optional[torch.Tensor] + ) -> "AttentionParams": + if attention_mask is None: + # No attention mask provided: assume all tokens are valid + return AttentionParams(attn_mode, split_attn, None, None, None, None, None) + else: + # Note: attention_mask is only for text tokens, not including image tokens + seqlens = attention_mask.sum(dim=1).to(torch.int32) + img_len # [B] + max_seqlen = attention_mask.shape[1] + img_len + + if split_attn: + # cu_seqlens is not needed for split attention + return AttentionParams(attn_mode, split_attn, img_len, attention_mask, seqlens, None, max_seqlen) + + # Convert attention mask to cumulative sequence lengths for flash attention + batch_size = attention_mask.shape[0] + cu_seqlens = torch.zeros([2 * batch_size + 1], dtype=torch.int32, device=attention_mask.device) + for i in range(batch_size): + cu_seqlens[2 * i + 1] = i * max_seqlen + seqlens[i] # end of valid tokens for query + cu_seqlens[2 * i + 2] = (i + 1) * max_seqlen # end of all tokens for query + + # Expand attention mask to include image tokens + attention_mask = torch.nn.functional.pad(attention_mask, (img_len, 0), value=1) # [B, img_len + L] + + if attn_mode == "xformers": + seqlens_list = seqlens.cpu().tolist() + attention_mask = xops.fmha.attn_bias.BlockDiagonalMask.from_seqlens( + seqlens_list, seqlens_list, device=attention_mask.device + ) + elif attn_mode == "torch": + attention_mask = attention_mask[:, None, None, :].to(torch.bool) # [B, 1, 1, img_len + L] + + return AttentionParams(attn_mode, split_attn, img_len, attention_mask, seqlens, cu_seqlens, max_seqlen) + + +def attention( + qkv_or_q: Union[torch.Tensor, list], + k: Optional[torch.Tensor] = None, + v: Optional[torch.Tensor] = None, + attn_params: Optional[AttentionParams] = None, + drop_rate: float = 0.0, +) -> torch.Tensor: + """ + Compute scaled dot-product attention with variable sequence lengths. + + Handles batches with different sequence lengths by splitting and + processing each sequence individually. + + Args: + qkv_or_q: Query tensor [B, L, H, D]. or list of such tensors. + k: Key tensor [B, L, H, D]. + v: Value tensor [B, L, H, D]. + attn_param: Attention parameters including mask and sequence lengths. + drop_rate: Attention dropout rate. + + Returns: + Attention output tensor [B, L, H*D]. + """ + if isinstance(qkv_or_q, list): + q, k, v = qkv_or_q + q: torch.Tensor = q + qkv_or_q.clear() + del qkv_or_q + else: + q: torch.Tensor = qkv_or_q + del qkv_or_q + assert k is not None and v is not None, "k and v must be provided if qkv_or_q is a tensor" + if attn_params is None: + attn_params = AttentionParams.create_attention_params("torch", False) + + # If split attn is False, attention mask is provided and all sequence lengths are same, we can trim the sequence + seqlen_trimmed = False + if not attn_params.split_attn and attn_params.attention_mask is not None and attn_params.seqlens is not None: + if torch.all(attn_params.seqlens == attn_params.seqlens[0]): + seqlen = attn_params.seqlens[0].item() + q = q[:, :seqlen] + k = k[:, :seqlen] + v = v[:, :seqlen] + max_seqlen = attn_params.max_seqlen + attn_params = AttentionParams.create_attention_params(attn_params.attn_mode, False) # do not in-place modify + attn_params.max_seqlen = max_seqlen # keep max_seqlen for padding + seqlen_trimmed = True + + # Determine tensor layout based on attention implementation + if attn_params.attn_mode == "torch" or ( + attn_params.attn_mode == "sageattn" and (attn_params.split_attn or attn_params.cu_seqlens is None) + ): + transpose_fn = lambda x: x.transpose(1, 2) # [B, H, L, D] for SDPA and sageattn with fixed length + # pad on sequence length dimension + pad_fn = lambda x, pad_to: torch.nn.functional.pad(x, (0, 0, 0, pad_to - x.shape[-2]), value=0) + else: + transpose_fn = lambda x: x # [B, L, H, D] for other implementations + # pad on sequence length dimension + pad_fn = lambda x, pad_to: torch.nn.functional.pad(x, (0, 0, 0, 0, 0, pad_to - x.shape[-3]), value=0) + + # Process each batch element with its valid sequence lengths + if attn_params.split_attn: + if attn_params.seqlens is None: + # If no seqlens provided, assume all tokens are valid + attn_params = AttentionParams.create_attention_params(attn_params.attn_mode, True) # do not in-place modify + attn_params.seqlens = torch.tensor([q.shape[1]] * q.shape[0], device=q.device) + attn_params.max_seqlen = q.shape[1] + q = [transpose_fn(q[i : i + 1, : attn_params.seqlens[i]]) for i in range(len(q))] + k = [transpose_fn(k[i : i + 1, : attn_params.seqlens[i]]) for i in range(len(k))] + v = [transpose_fn(v[i : i + 1, : attn_params.seqlens[i]]) for i in range(len(v))] + else: + q = transpose_fn(q) + k = transpose_fn(k) + v = transpose_fn(v) + + if attn_params.attn_mode == "torch": + if attn_params.split_attn: + x = [] + for i in range(len(q)): + x_i = torch.nn.functional.scaled_dot_product_attention(q[i], k[i], v[i], dropout_p=drop_rate) + q[i] = None + k[i] = None + v[i] = None + x.append(pad_fn(x_i, attn_params.max_seqlen)) # B, H, L, D + x = torch.cat(x, dim=0) + del q, k, v + + else: + x = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_params.attention_mask, dropout_p=drop_rate) + del q, k, v + + elif attn_params.attn_mode == "xformers": + if attn_params.split_attn: + x = [] + for i in range(len(q)): + x_i = xops.memory_efficient_attention(q[i], k[i], v[i], p=drop_rate) + q[i] = None + k[i] = None + v[i] = None + x.append(pad_fn(x_i, attn_params.max_seqlen)) # B, L, H, D + x = torch.cat(x, dim=0) + del q, k, v + + else: + x = xops.memory_efficient_attention(q, k, v, attn_bias=attn_params.attention_mask, p=drop_rate) + del q, k, v + + elif attn_params.attn_mode == "sageattn": + if attn_params.split_attn: + x = [] + for i in range(len(q)): + # HND seems to cause an error + x_i = sageattn(q[i], k[i], v[i]) # B, H, L, D. No dropout support + q[i] = None + k[i] = None + v[i] = None + x.append(pad_fn(x_i, attn_params.max_seqlen)) # B, H, L, D + x = torch.cat(x, dim=0) + del q, k, v + elif attn_params.cu_seqlens is None: # all tokens are valid + x = sageattn(q, k, v) # B, L, H, D. No dropout support + del q, k, v + else: + # Reshape to [(bxs), a, d] + batch_size, seqlen = q.shape[0], q.shape[1] + q = q.view(q.shape[0] * q.shape[1], *q.shape[2:]) # [B*L, H, D] + k = k.view(k.shape[0] * k.shape[1], *k.shape[2:]) # [B*L, H, D] + v = v.view(v.shape[0] * v.shape[1], *v.shape[2:]) # [B*L, H, D] + + # Assume cu_seqlens_q == cu_seqlens_kv and max_seqlen_q == max_seqlen_kv. No dropout support + x = sageattn_varlen( + q, k, v, attn_params.cu_seqlens, attn_params.cu_seqlens, attn_params.max_seqlen, attn_params.max_seqlen + ) + del q, k, v + + # Reshape x with shape [(bxs), a, d] to [b, s, a, d] + x = x.view(batch_size, seqlen, x.shape[-2], x.shape[-1]) # B, L, H, D + + elif attn_params.attn_mode == "flash": + if attn_params.split_attn: + x = [] + for i in range(len(q)): + # HND seems to cause an error + x_i = flash_attn_func(q[i], k[i], v[i], drop_rate) # B, L, H, D + q[i] = None + k[i] = None + v[i] = None + x.append(pad_fn(x_i, attn_params.max_seqlen)) # B, L, H, D + x = torch.cat(x, dim=0) + del q, k, v + elif attn_params.cu_seqlens is None: # all tokens are valid + x = flash_attn_func(q, k, v, drop_rate) # B, L, H, D + del q, k, v + else: + # Reshape to [(bxs), a, d] + batch_size, seqlen = q.shape[0], q.shape[1] + q = q.view(q.shape[0] * q.shape[1], *q.shape[2:]) # [B*L, H, D] + k = k.view(k.shape[0] * k.shape[1], *k.shape[2:]) # [B*L, H, D] + v = v.view(v.shape[0] * v.shape[1], *v.shape[2:]) # [B*L, H, D] + + # Assume cu_seqlens_q == cu_seqlens_kv and max_seqlen_q == max_seqlen_kv + x = flash_attn_varlen_func( + q, k, v, attn_params.cu_seqlens, attn_params.cu_seqlens, attn_params.max_seqlen, attn_params.max_seqlen, drop_rate + ) + del q, k, v + + # Reshape x with shape [(bxs), a, d] to [b, s, a, d] + x = x.view(batch_size, seqlen, x.shape[-2], x.shape[-1]) # B, L, H, D + + else: + # Currently only PyTorch SDPA and xformers are implemented + raise ValueError(f"Unsupported attention mode: {attn_params.attn_mode}") + + x = transpose_fn(x) # [B, L, H, D] + x = x.reshape(x.shape[0], x.shape[1], -1) # [B, L, H*D] + + if seqlen_trimmed: + x = torch.nn.functional.pad(x, (0, 0, 0, attn_params.max_seqlen - x.shape[1]), value=0) # pad back to max_seqlen + + return x diff --git a/library/chroma_models.py b/library/chroma_models.py new file mode 100644 index 000000000..d5ac1f39e --- /dev/null +++ b/library/chroma_models.py @@ -0,0 +1,744 @@ +# copy from the official repo: https://github.com/lodestone-rock/flow/blob/master/src/models/chroma/model.py +# and modified +# licensed under Apache License 2.0 + +import math +from dataclasses import dataclass + +import torch +from einops import rearrange +from torch import Tensor, nn +import torch.nn.functional as F +import torch.utils.checkpoint as ckpt + +from .flux_models import attention, rope, apply_rope, EmbedND, timestep_embedding, MLPEmbedder, RMSNorm, QKNorm, SelfAttention, Flux +from . import custom_offloading_utils + + +def distribute_modulations(tensor: torch.Tensor, depth_single_blocks, depth_double_blocks): + """ + Distributes slices of the tensor into the block_dict as ModulationOut objects. + + Args: + tensor (torch.Tensor): Input tensor with shape [batch_size, vectors, dim]. + """ + batch_size, vectors, dim = tensor.shape + + block_dict = {} + + # HARD CODED VALUES! lookup table for the generated vectors + # TODO: move this into chroma config! + # Add 38 single mod blocks + for i in range(depth_single_blocks): + key = f"single_blocks.{i}.modulation.lin" + block_dict[key] = None + + # Add 19 image double blocks + for i in range(depth_double_blocks): + key = f"double_blocks.{i}.img_mod.lin" + block_dict[key] = None + + # Add 19 text double blocks + for i in range(depth_double_blocks): + key = f"double_blocks.{i}.txt_mod.lin" + block_dict[key] = None + + # Add the final layer + block_dict["final_layer.adaLN_modulation.1"] = None + # 6.2b version + # block_dict["lite_double_blocks.4.img_mod.lin"] = None + # block_dict["lite_double_blocks.4.txt_mod.lin"] = None + + idx = 0 # Index to keep track of the vector slices + + for key in block_dict.keys(): + if "single_blocks" in key: + # Single block: 1 ModulationOut + block_dict[key] = ModulationOut( + shift=tensor[:, idx : idx + 1, :], + scale=tensor[:, idx + 1 : idx + 2, :], + gate=tensor[:, idx + 2 : idx + 3, :], + ) + idx += 3 # Advance by 3 vectors + + elif "img_mod" in key: + # Double block: List of 2 ModulationOut + double_block = [] + for _ in range(2): # Create 2 ModulationOut objects + double_block.append( + ModulationOut( + shift=tensor[:, idx : idx + 1, :], + scale=tensor[:, idx + 1 : idx + 2, :], + gate=tensor[:, idx + 2 : idx + 3, :], + ) + ) + idx += 3 # Advance by 3 vectors per ModulationOut + block_dict[key] = double_block + + elif "txt_mod" in key: + # Double block: List of 2 ModulationOut + double_block = [] + for _ in range(2): # Create 2 ModulationOut objects + double_block.append( + ModulationOut( + shift=tensor[:, idx : idx + 1, :], + scale=tensor[:, idx + 1 : idx + 2, :], + gate=tensor[:, idx + 2 : idx + 3, :], + ) + ) + idx += 3 # Advance by 3 vectors per ModulationOut + block_dict[key] = double_block + + elif "final_layer" in key: + # Final layer: 1 ModulationOut + block_dict[key] = [ + tensor[:, idx : idx + 1, :], + tensor[:, idx + 1 : idx + 2, :], + ] + idx += 2 # Advance by 3 vectors + + return block_dict + + +class Approximator(nn.Module): + def __init__(self, in_dim: int, out_dim: int, hidden_dim: int, n_layers=4): + super().__init__() + self.in_proj = nn.Linear(in_dim, hidden_dim, bias=True) + self.layers = nn.ModuleList([MLPEmbedder(hidden_dim, hidden_dim) for x in range(n_layers)]) + self.norms = nn.ModuleList([RMSNorm(hidden_dim) for x in range(n_layers)]) + self.out_proj = nn.Linear(hidden_dim, out_dim) + + @property + def device(self): + # Get the device of the module (assumes all parameters are on the same device) + return next(self.parameters()).device + + def enable_gradient_checkpointing(self): + for layer in self.layers: + layer.enable_gradient_checkpointing() + + def disable_gradient_checkpointing(self): + for layer in self.layers: + layer.disable_gradient_checkpointing() + + def forward(self, x: Tensor) -> Tensor: + x = self.in_proj(x) + + for layer, norms in zip(self.layers, self.norms): + x = x + layer(norms(x)) + + x = self.out_proj(x) + + return x + + +@dataclass +class ModulationOut: + shift: Tensor + scale: Tensor + gate: Tensor + + +def _modulation_shift_scale_fn(x, scale, shift): + return (1 + scale) * x + shift + + +def _modulation_gate_fn(x, gate, gate_params): + return x + gate * gate_params + + +class DoubleStreamBlock(nn.Module): + def __init__( + self, + hidden_size: int, + num_heads: int, + mlp_ratio: float, + qkv_bias: bool = False, + ): + super().__init__() + + mlp_hidden_dim = int(hidden_size * mlp_ratio) + self.num_heads = num_heads + self.hidden_size = hidden_size + self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.img_attn = SelfAttention( + dim=hidden_size, + num_heads=num_heads, + qkv_bias=qkv_bias, + ) + + self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.img_mlp = nn.Sequential( + nn.Linear(hidden_size, mlp_hidden_dim, bias=True), + nn.GELU(approximate="tanh"), + nn.Linear(mlp_hidden_dim, hidden_size, bias=True), + ) + + self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.txt_attn = SelfAttention( + dim=hidden_size, + num_heads=num_heads, + qkv_bias=qkv_bias, + ) + + self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.txt_mlp = nn.Sequential( + nn.Linear(hidden_size, mlp_hidden_dim, bias=True), + nn.GELU(approximate="tanh"), + nn.Linear(mlp_hidden_dim, hidden_size, bias=True), + ) + + self.gradient_checkpointing = False + + @property + def device(self): + # Get the device of the module (assumes all parameters are on the same device) + return next(self.parameters()).device + + def modulation_shift_scale_fn(self, x, scale, shift): + return _modulation_shift_scale_fn(x, scale, shift) + + def modulation_gate_fn(self, x, gate, gate_params): + return _modulation_gate_fn(x, gate, gate_params) + + def enable_gradient_checkpointing(self): + self.gradient_checkpointing = True + + def disable_gradient_checkpointing(self): + self.gradient_checkpointing = False + + def _forward( + self, + img: Tensor, + txt: Tensor, + pe: list[Tensor], + distill_vec: list[ModulationOut], + txt_seq_len: Tensor, + ) -> tuple[Tensor, Tensor]: + (img_mod1, img_mod2), (txt_mod1, txt_mod2) = distill_vec + + # prepare image for attention + img_modulated = self.img_norm1(img) + # replaced with compiled fn + # img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift + img_modulated = self.modulation_shift_scale_fn(img_modulated, img_mod1.scale, img_mod1.shift) + img_qkv = self.img_attn.qkv(img_modulated) + del img_modulated + + img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads) + del img_qkv + img_q, img_k = self.img_attn.norm(img_q, img_k, img_v) + + # prepare txt for attention + txt_modulated = self.txt_norm1(txt) + # replaced with compiled fn + # txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift + txt_modulated = self.modulation_shift_scale_fn(txt_modulated, txt_mod1.scale, txt_mod1.shift) + txt_qkv = self.txt_attn.qkv(txt_modulated) + del txt_modulated + + txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads) + del txt_qkv + txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v) + + # run actual attention: we split the batch into each element + max_txt_len = torch.max(txt_seq_len).item() + img_len = img_q.shape[-2] # max 64 + txt_q = list(torch.chunk(txt_q, txt_q.shape[0], dim=0)) # list of [B, H, L, D] tensors + txt_k = list(torch.chunk(txt_k, txt_k.shape[0], dim=0)) + txt_v = list(torch.chunk(txt_v, txt_v.shape[0], dim=0)) + img_q = list(torch.chunk(img_q, img_q.shape[0], dim=0)) + img_k = list(torch.chunk(img_k, img_k.shape[0], dim=0)) + img_v = list(torch.chunk(img_v, img_v.shape[0], dim=0)) + txt_attn = [] + img_attn = [] + for i in range(txt.shape[0]): + txt_q[i] = txt_q[i][:, :, : txt_seq_len[i]] + q = torch.cat((img_q[i], txt_q[i]), dim=2) + txt_q[i] = None + img_q[i] = None + + txt_k[i] = txt_k[i][:, :, : txt_seq_len[i]] + k = torch.cat((img_k[i], txt_k[i]), dim=2) + txt_k[i] = None + img_k[i] = None + + txt_v[i] = txt_v[i][:, :, : txt_seq_len[i]] + v = torch.cat((img_v[i], txt_v[i]), dim=2) + txt_v[i] = None + img_v[i] = None + + attn = attention(q, k, v, pe=pe[i : i + 1, :, : q.shape[2]], attn_mask=None) # attn = (1, L, D) + del q, k, v + img_attn_i = attn[:, :img_len, :] + txt_attn_i = torch.zeros((1, max_txt_len, attn.shape[-1]), dtype=attn.dtype, device=self.device) + txt_attn_i[:, : txt_seq_len[i], :] = attn[:, img_len:, :] + del attn + txt_attn.append(txt_attn_i) + img_attn.append(img_attn_i) + + txt_attn = torch.cat(txt_attn, dim=0) + img_attn = torch.cat(img_attn, dim=0) + + # q = torch.cat((txt_q, img_q), dim=2) + # k = torch.cat((txt_k, img_k), dim=2) + # v = torch.cat((txt_v, img_v), dim=2) + + # attn = attention(q, k, v, pe=pe, attn_mask=mask) + # txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :] + + # calculate the img blocks + # replaced with compiled fn + # img = img + img_mod1.gate * self.img_attn.proj(img_attn) + # img = img + img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift) + img = self.modulation_gate_fn(img, img_mod1.gate, self.img_attn.proj(img_attn)) + del img_attn, img_mod1 + img = self.modulation_gate_fn( + img, + img_mod2.gate, + self.img_mlp(self.modulation_shift_scale_fn(self.img_norm2(img), img_mod2.scale, img_mod2.shift)), + ) + del img_mod2 + + # calculate the txt blocks + # replaced with compiled fn + # txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn) + # txt = txt + txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift) + txt = self.modulation_gate_fn(txt, txt_mod1.gate, self.txt_attn.proj(txt_attn)) + del txt_attn, txt_mod1 + txt = self.modulation_gate_fn( + txt, + txt_mod2.gate, + self.txt_mlp(self.modulation_shift_scale_fn(self.txt_norm2(txt), txt_mod2.scale, txt_mod2.shift)), + ) + del txt_mod2 + + return img, txt + + def forward( + self, + img: Tensor, + txt: Tensor, + pe: Tensor, + distill_vec: list[ModulationOut], + txt_seq_len: Tensor, + ) -> tuple[Tensor, Tensor]: + if self.training and self.gradient_checkpointing: + return ckpt.checkpoint(self._forward, img, txt, pe, distill_vec, txt_seq_len, use_reentrant=False) + else: + return self._forward(img, txt, pe, distill_vec, txt_seq_len) + + +class SingleStreamBlock(nn.Module): + """ + A DiT block with parallel linear layers as described in + https://arxiv.org/abs/2302.05442 and adapted modulation interface. + """ + + def __init__( + self, + hidden_size: int, + num_heads: int, + mlp_ratio: float = 4.0, + qk_scale: float | None = None, + ): + super().__init__() + self.hidden_dim = hidden_size + self.num_heads = num_heads + head_dim = hidden_size // num_heads + self.scale = qk_scale or head_dim**-0.5 + + self.mlp_hidden_dim = int(hidden_size * mlp_ratio) + # qkv and mlp_in + self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim) + # proj and mlp_out + self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size) + + self.norm = QKNorm(head_dim) + + self.hidden_size = hidden_size + self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + + self.mlp_act = nn.GELU(approximate="tanh") + + self.gradient_checkpointing = False + + @property + def device(self): + # Get the device of the module (assumes all parameters are on the same device) + return next(self.parameters()).device + + def modulation_shift_scale_fn(self, x, scale, shift): + return _modulation_shift_scale_fn(x, scale, shift) + + def modulation_gate_fn(self, x, gate, gate_params): + return _modulation_gate_fn(x, gate, gate_params) + + def enable_gradient_checkpointing(self): + self.gradient_checkpointing = True + + def disable_gradient_checkpointing(self): + self.gradient_checkpointing = False + + def _forward(self, x: Tensor, pe: list[Tensor], distill_vec: list[ModulationOut], txt_seq_len: Tensor) -> Tensor: + mod = distill_vec + # replaced with compiled fn + # x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift + x_mod = self.modulation_shift_scale_fn(self.pre_norm(x), mod.scale, mod.shift) + qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1) + del x_mod + + q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads) + del qkv + q, k = self.norm(q, k, v) + + # # compute attention + # attn = attention(q, k, v, pe=pe, attn_mask=mask) + + # compute attention: we split the batch into each element + max_txt_len = torch.max(txt_seq_len).item() + img_len = q.shape[-2] - max_txt_len + q = list(torch.chunk(q, q.shape[0], dim=0)) + k = list(torch.chunk(k, k.shape[0], dim=0)) + v = list(torch.chunk(v, v.shape[0], dim=0)) + attn = [] + for i in range(x.size(0)): + q[i] = q[i][:, :, : img_len + txt_seq_len[i]] + k[i] = k[i][:, :, : img_len + txt_seq_len[i]] + v[i] = v[i][:, :, : img_len + txt_seq_len[i]] + attn_trimmed = attention(q[i], k[i], v[i], pe=pe[i : i + 1, :, : img_len + txt_seq_len[i]], attn_mask=None) + q[i] = None + k[i] = None + v[i] = None + + attn_i = torch.zeros((1, x.shape[1], attn_trimmed.shape[-1]), dtype=attn_trimmed.dtype, device=self.device) + attn_i[:, : img_len + txt_seq_len[i], :] = attn_trimmed + del attn_trimmed + attn.append(attn_i) + + attn = torch.cat(attn, dim=0) + + # compute activation in mlp stream, cat again and run second linear layer + mlp = self.mlp_act(mlp) + output = self.linear2(torch.cat((attn, mlp), 2)) + del attn, mlp + # replaced with compiled fn + # return x + mod.gate * output + return self.modulation_gate_fn(x, mod.gate, output) + + def forward(self, x: Tensor, pe: Tensor, distill_vec: list[ModulationOut], txt_seq_len: Tensor) -> Tensor: + if self.training and self.gradient_checkpointing: + return ckpt.checkpoint(self._forward, x, pe, distill_vec, txt_seq_len, use_reentrant=False) + else: + return self._forward(x, pe, distill_vec, txt_seq_len) + + +class LastLayer(nn.Module): + def __init__( + self, + hidden_size: int, + patch_size: int, + out_channels: int, + ): + super().__init__() + self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True) + + @property + def device(self): + # Get the device of the module (assumes all parameters are on the same device) + return next(self.parameters()).device + + def modulation_shift_scale_fn(self, x, scale, shift): + return _modulation_shift_scale_fn(x, scale, shift) + + def forward(self, x: Tensor, distill_vec: list[Tensor]) -> Tensor: + shift, scale = distill_vec + shift = shift.squeeze(1) + scale = scale.squeeze(1) + # replaced with compiled fn + # x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :] + x = self.modulation_shift_scale_fn(self.norm_final(x), scale[:, None, :], shift[:, None, :]) + x = self.linear(x) + return x + + +@dataclass +class ChromaParams: + in_channels: int + context_in_dim: int + hidden_size: int + mlp_ratio: float + num_heads: int + depth: int + depth_single_blocks: int + axes_dim: list[int] + theta: int + qkv_bias: bool + guidance_embed: bool + approximator_in_dim: int + approximator_depth: int + approximator_hidden_size: int + _use_compiled: bool + + +chroma_params = ChromaParams( + in_channels=64, + context_in_dim=4096, + hidden_size=3072, + mlp_ratio=4.0, + num_heads=24, + depth=19, + depth_single_blocks=38, + axes_dim=[16, 56, 56], + theta=10_000, + qkv_bias=True, + guidance_embed=True, + approximator_in_dim=64, + approximator_depth=5, + approximator_hidden_size=5120, + _use_compiled=False, +) + + +def modify_mask_to_attend_padding(mask, max_seq_length, num_extra_padding=8): + """ + Modifies attention mask to allow attention to a few extra padding tokens. + + Args: + mask: Original attention mask (1 for tokens to attend to, 0 for masked tokens) + max_seq_length: Maximum sequence length of the model + num_extra_padding: Number of padding tokens to unmask + + Returns: + Modified mask + """ + # Get the actual sequence length from the mask + seq_length = mask.sum(dim=-1) + batch_size = mask.shape[0] + + modified_mask = mask.clone() + + for i in range(batch_size): + current_seq_len = int(seq_length[i].item()) + + # Only add extra padding tokens if there's room + if current_seq_len < max_seq_length: + # Calculate how many padding tokens we can unmask + available_padding = max_seq_length - current_seq_len + tokens_to_unmask = min(num_extra_padding, available_padding) + + # Unmask the specified number of padding tokens right after the sequence + modified_mask[i, current_seq_len : current_seq_len + tokens_to_unmask] = 1 + + return modified_mask + + +class Chroma(Flux): + """ + Transformer model for flow matching on sequences. + """ + + def __init__(self, params: ChromaParams): + nn.Module.__init__(self) + self.params = params + self.in_channels = params.in_channels + self.out_channels = self.in_channels + if params.hidden_size % params.num_heads != 0: + raise ValueError(f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}") + pe_dim = params.hidden_size // params.num_heads + if sum(params.axes_dim) != pe_dim: + raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}") + self.hidden_size = params.hidden_size + self.num_heads = params.num_heads + self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim) + self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True) + + # TODO: need proper mapping for this approximator output! + # currently the mapping is hardcoded in distribute_modulations function + self.distilled_guidance_layer = Approximator( + params.approximator_in_dim, + self.hidden_size, + params.approximator_hidden_size, + params.approximator_depth, + ) + self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size) + + self.double_blocks = nn.ModuleList( + [ + DoubleStreamBlock( + self.hidden_size, + self.num_heads, + mlp_ratio=params.mlp_ratio, + qkv_bias=params.qkv_bias, + ) + for _ in range(params.depth) + ] + ) + + self.single_blocks = nn.ModuleList( + [ + SingleStreamBlock( + self.hidden_size, + self.num_heads, + mlp_ratio=params.mlp_ratio, + ) + for _ in range(params.depth_single_blocks) + ] + ) + + self.final_layer = LastLayer( + self.hidden_size, + 1, + self.out_channels, + ) + + # TODO: move this hardcoded value to config + # single layer has 3 modulation vectors + # double layer has 6 modulation vectors for each expert + # final layer has 2 modulation vectors + self.mod_index_length = 3 * params.depth_single_blocks + 2 * 6 * params.depth + 2 + self.depth_single_blocks = params.depth_single_blocks + self.depth_double_blocks = params.depth + # self.mod_index = torch.tensor(list(range(self.mod_index_length)), device=0) + self.register_buffer( + "mod_index", + torch.tensor(list(range(self.mod_index_length)), device="cpu"), + persistent=False, + ) + self.approximator_in_dim = params.approximator_in_dim + + self.blocks_to_swap = None + self.offloader_double = None + self.offloader_single = None + self.num_double_blocks = len(self.double_blocks) + self.num_single_blocks = len(self.single_blocks) + + # Initialize properties required by Flux parent class + self.gradient_checkpointing = False + self.cpu_offload_checkpointing = False + + def get_model_type(self) -> str: + return "chroma" + + def enable_gradient_checkpointing(self, cpu_offload: bool = False): + self.gradient_checkpointing = True + self.cpu_offload_checkpointing = cpu_offload + + self.distilled_guidance_layer.enable_gradient_checkpointing() + for block in self.double_blocks + self.single_blocks: + block.enable_gradient_checkpointing() + + print(f"Chroma: Gradient checkpointing enabled.") + + def disable_gradient_checkpointing(self): + self.gradient_checkpointing = False + self.cpu_offload_checkpointing = False + + self.distilled_guidance_layer.disable_gradient_checkpointing() + for block in self.double_blocks + self.single_blocks: + block.disable_gradient_checkpointing() + + print("Chroma: Gradient checkpointing disabled.") + + def get_mod_vectors(self, timesteps: Tensor, guidance: Tensor | None = None, batch_size: int | None = None) -> Tensor: + # We extract this logic from forward to clarify the propagation of the gradients + # original comment: https://github.com/lodestone-rock/flow/blob/c76f63058980d0488826936025889e256a2e0458/src/models/chroma/model.py#L195 + + # print(f"Chroma get_input_vec: timesteps {timesteps}, guidance: {guidance}, batch_size: {batch_size}") + distill_timestep = timestep_embedding(timesteps, self.approximator_in_dim // 4) + # TODO: need to add toggle to omit this from schnell but that's not a priority + distil_guidance = timestep_embedding(guidance, self.approximator_in_dim // 4) + # get all modulation index + modulation_index = timestep_embedding(self.mod_index, self.approximator_in_dim // 2) + # we need to broadcast the modulation index here so each batch has all of the index + modulation_index = modulation_index.unsqueeze(0).repeat(batch_size, 1, 1) + # and we need to broadcast timestep and guidance along too + timestep_guidance = torch.cat([distill_timestep, distil_guidance], dim=1).unsqueeze(1).repeat(1, self.mod_index_length, 1) + # then and only then we could concatenate it together + input_vec = torch.cat([timestep_guidance, modulation_index], dim=-1) + + mod_vectors = self.distilled_guidance_layer(input_vec) + return mod_vectors + + def forward( + self, + img: Tensor, + img_ids: Tensor, + txt: Tensor, + txt_ids: Tensor, + timesteps: Tensor, + y: Tensor, + block_controlnet_hidden_states=None, + block_controlnet_single_hidden_states=None, + guidance: Tensor | None = None, + txt_attention_mask: Tensor | None = None, + attn_padding: int = 1, + mod_vectors: Tensor | None = None, + ) -> Tensor: + # print( + # f"Chroma forward: img shape {img.shape}, txt shape {txt.shape}, img_ids shape {img_ids.shape}, txt_ids shape {txt_ids.shape}" + # ) + # print(f"input_vec shape: {input_vec.shape if input_vec is not None else 'None'}") + # print(f"timesteps: {timesteps}, guidance: {guidance}") + + if img.ndim != 3 or txt.ndim != 3: + raise ValueError("Input img and txt tensors must have 3 dimensions.") + + # running on sequences img + img = self.img_in(img) + txt = self.txt_in(txt) + + if mod_vectors is None: # fallback to the original logic + with torch.no_grad(): + mod_vectors = self.get_mod_vectors(timesteps, guidance, img.shape[0]) + mod_vectors_dict = distribute_modulations(mod_vectors, self.depth_single_blocks, self.depth_double_blocks) + + # calculate text length for each batch instead of masking + txt_emb_len = txt.shape[1] + txt_seq_len = txt_attention_mask[:, :txt_emb_len].sum(dim=-1).to(torch.int64) # (batch_size, ) + txt_seq_len = torch.clip(txt_seq_len + attn_padding, 0, txt_emb_len) + max_txt_len = torch.max(txt_seq_len).item() # max text length in the batch + # print(f"max_txt_len: {max_txt_len}, txt_seq_len: {txt_seq_len}") + + # trim txt embedding to the text length + txt = txt[:, :max_txt_len, :] + + # create positional encoding for the text and image + ids = torch.cat((img_ids, txt_ids[:, :max_txt_len]), dim=1) # reverse order of ids for faster attention + pe = self.pe_embedder(ids) # B, 1, seq_length, 64, 2, 2 + + for i, block in enumerate(self.double_blocks): + if self.blocks_to_swap: + self.offloader_double.wait_for_block(i) + + # the guidance replaced by FFN output + img_mod = mod_vectors_dict.pop(f"double_blocks.{i}.img_mod.lin") + txt_mod = mod_vectors_dict.pop(f"double_blocks.{i}.txt_mod.lin") + double_mod = [img_mod, txt_mod] + del img_mod, txt_mod + + img, txt = block(img=img, txt=txt, pe=pe, distill_vec=double_mod, txt_seq_len=txt_seq_len) + del double_mod + + if self.blocks_to_swap: + self.offloader_double.submit_move_blocks(self.double_blocks, i) + + img = torch.cat((img, txt), 1) + del txt + + for i, block in enumerate(self.single_blocks): + if self.blocks_to_swap: + self.offloader_single.wait_for_block(i) + + single_mod = mod_vectors_dict.pop(f"single_blocks.{i}.modulation.lin") + img = block(img, pe=pe, distill_vec=single_mod, txt_seq_len=txt_seq_len) + del single_mod + + if self.blocks_to_swap: + self.offloader_single.submit_move_blocks(self.single_blocks, i) + + img = img[:, :-max_txt_len, ...] + final_mod = mod_vectors_dict["final_layer.adaLN_modulation.1"] + img = self.final_layer(img, distill_vec=final_mod) # (N, T, patch_size ** 2 * out_channels) + return img diff --git a/library/config_util.py b/library/config_util.py index 10b2457f3..53727f252 100644 --- a/library/config_util.py +++ b/library/config_util.py @@ -10,13 +10,7 @@ from pathlib import Path # from toolz import curry -from typing import ( - List, - Optional, - Sequence, - Tuple, - Union, -) +from typing import Dict, List, Optional, Sequence, Tuple, Union import toml import voluptuous @@ -78,6 +72,10 @@ class BaseSubsetParams: caption_tag_dropout_rate: float = 0.0 token_warmup_min: int = 1 token_warmup_step: float = 0 + custom_attributes: Optional[Dict[str, Any]] = None + validation_seed: int = 0 + validation_split: float = 0.0 + resize_interpolation: Optional[str] = None @dataclass @@ -104,12 +102,12 @@ class ControlNetSubsetParams(BaseSubsetParams): @dataclass class BaseDatasetParams: - tokenizer: Union[CLIPTokenizer, List[CLIPTokenizer]] = None - max_token_length: int = None resolution: Optional[Tuple[int, int]] = None network_multiplier: float = 1.0 debug_dataset: bool = False - + validation_seed: Optional[int] = None + validation_split: float = 0.0 + resize_interpolation: Optional[str] = None @dataclass class DreamBoothDatasetParams(BaseDatasetParams): @@ -120,8 +118,7 @@ class DreamBoothDatasetParams(BaseDatasetParams): bucket_reso_steps: int = 64 bucket_no_upscale: bool = False prior_loss_weight: float = 1.0 - - + @dataclass class FineTuningDatasetParams(BaseDatasetParams): batch_size: int = 1 @@ -199,6 +196,8 @@ def __validate_and_convert_scalar_or_twodim(klass, value: Union[float, Sequence] "token_warmup_step": Any(float, int), "caption_prefix": str, "caption_suffix": str, + "custom_attributes": dict, + "resize_interpolation": str, } # DO means DropOut DO_SUBSET_ASCENDABLE_SCHEMA = { @@ -240,8 +239,11 @@ def __validate_and_convert_scalar_or_twodim(klass, value: Union[float, Sequence] "enable_bucket": bool, "max_bucket_reso": int, "min_bucket_reso": int, + "validation_seed": int, + "validation_split": float, "resolution": functools.partial(__validate_and_convert_scalar_or_twodim.__func__, int), "network_multiplier": float, + "resize_interpolation": str, } # options handled by argparse but not handled by user config @@ -468,118 +470,138 @@ def search_value(key: str, fallbacks: Sequence[dict], default_value=None): return default_value - -def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlueprint): +def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlueprint) -> Tuple[DatasetGroup, Optional[DatasetGroup]]: datasets: List[Union[DreamBoothDataset, FineTuningDataset, ControlNetDataset]] = [] for dataset_blueprint in dataset_group_blueprint.datasets: + extra_dataset_params = {} + if dataset_blueprint.is_controlnet: subset_klass = ControlNetSubset dataset_klass = ControlNetDataset elif dataset_blueprint.is_dreambooth: subset_klass = DreamBoothSubset dataset_klass = DreamBoothDataset + # DreamBooth datasets support splitting training and validation datasets + extra_dataset_params = {"is_training_dataset": True} else: subset_klass = FineTuningSubset dataset_klass = FineTuningDataset subsets = [subset_klass(**asdict(subset_blueprint.params)) for subset_blueprint in dataset_blueprint.subsets] - dataset = dataset_klass(subsets=subsets, **asdict(dataset_blueprint.params)) + dataset = dataset_klass(subsets=subsets, **asdict(dataset_blueprint.params), **extra_dataset_params) datasets.append(dataset) - # print info - info = "" - for i, dataset in enumerate(datasets): - is_dreambooth = isinstance(dataset, DreamBoothDataset) - is_controlnet = isinstance(dataset, ControlNetDataset) - info += dedent( - f"""\ - [Dataset {i}] - batch_size: {dataset.batch_size} - resolution: {(dataset.width, dataset.height)} - enable_bucket: {dataset.enable_bucket} - network_multiplier: {dataset.network_multiplier} - """ - ) + val_datasets: List[Union[DreamBoothDataset, FineTuningDataset, ControlNetDataset]] = [] + for dataset_blueprint in dataset_group_blueprint.datasets: + if dataset_blueprint.params.validation_split < 0.0 or dataset_blueprint.params.validation_split > 1.0: + logging.warning(f"Dataset param `validation_split` ({dataset_blueprint.params.validation_split}) is not a valid number between 0.0 and 1.0, skipping validation split...") + continue - if dataset.enable_bucket: - info += indent( - dedent( - f"""\ - min_bucket_reso: {dataset.min_bucket_reso} - max_bucket_reso: {dataset.max_bucket_reso} - bucket_reso_steps: {dataset.bucket_reso_steps} - bucket_no_upscale: {dataset.bucket_no_upscale} - \n""" - ), - " ", - ) - else: - info += "\n" - - for j, subset in enumerate(dataset.subsets): - info += indent( - dedent( - f"""\ - [Subset {j} of Dataset {i}] - image_dir: "{subset.image_dir}" - image_count: {subset.img_count} - num_repeats: {subset.num_repeats} - shuffle_caption: {subset.shuffle_caption} - keep_tokens: {subset.keep_tokens} - keep_tokens_separator: {subset.keep_tokens_separator} - caption_separator: {subset.caption_separator} - secondary_separator: {subset.secondary_separator} - enable_wildcard: {subset.enable_wildcard} - caption_dropout_rate: {subset.caption_dropout_rate} - caption_dropout_every_n_epoches: {subset.caption_dropout_every_n_epochs} - caption_tag_dropout_rate: {subset.caption_tag_dropout_rate} - caption_prefix: {subset.caption_prefix} - caption_suffix: {subset.caption_suffix} - color_aug: {subset.color_aug} - flip_aug: {subset.flip_aug} - face_crop_aug_range: {subset.face_crop_aug_range} - random_crop: {subset.random_crop} - token_warmup_min: {subset.token_warmup_min}, - token_warmup_step: {subset.token_warmup_step}, - alpha_mask: {subset.alpha_mask}, - """ - ), - " ", - ) + # if the dataset isn't setting a validation split, there is no current validation dataset + if dataset_blueprint.params.validation_split == 0.0: + continue - if is_dreambooth: - info += indent( - dedent( - f"""\ - is_reg: {subset.is_reg} - class_tokens: {subset.class_tokens} - caption_extension: {subset.caption_extension} - \n""" - ), - " ", - ) - elif not is_controlnet: - info += indent( - dedent( - f"""\ - metadata_file: {subset.metadata_file} - \n""" - ), - " ", - ) + extra_dataset_params = {} + if dataset_blueprint.is_controlnet: + subset_klass = ControlNetSubset + dataset_klass = ControlNetDataset + elif dataset_blueprint.is_dreambooth: + subset_klass = DreamBoothSubset + dataset_klass = DreamBoothDataset + # DreamBooth datasets support splitting training and validation datasets + extra_dataset_params = {"is_training_dataset": False} + else: + subset_klass = FineTuningSubset + dataset_klass = FineTuningDataset - logger.info(f"{info}") + subsets = [subset_klass(**asdict(subset_blueprint.params)) for subset_blueprint in dataset_blueprint.subsets] + dataset = dataset_klass(subsets=subsets, **asdict(dataset_blueprint.params), **extra_dataset_params) + val_datasets.append(dataset) + + def print_info(_datasets, dataset_type: str): + info = "" + for i, dataset in enumerate(_datasets): + is_dreambooth = isinstance(dataset, DreamBoothDataset) + is_controlnet = isinstance(dataset, ControlNetDataset) + info += dedent(f"""\ + [{dataset_type} {i}] + batch_size: {dataset.batch_size} + resolution: {(dataset.width, dataset.height)} + resize_interpolation: {dataset.resize_interpolation} + enable_bucket: {dataset.enable_bucket} + """) + + if dataset.enable_bucket: + info += indent(dedent(f"""\ + min_bucket_reso: {dataset.min_bucket_reso} + max_bucket_reso: {dataset.max_bucket_reso} + bucket_reso_steps: {dataset.bucket_reso_steps} + bucket_no_upscale: {dataset.bucket_no_upscale} + \n"""), " ") + else: + info += "\n" + + for j, subset in enumerate(dataset.subsets): + info += indent(dedent(f"""\ + [Subset {j} of {dataset_type} {i}] + image_dir: "{subset.image_dir}" + image_count: {subset.img_count} + num_repeats: {subset.num_repeats} + shuffle_caption: {subset.shuffle_caption} + keep_tokens: {subset.keep_tokens} + caption_dropout_rate: {subset.caption_dropout_rate} + caption_dropout_every_n_epochs: {subset.caption_dropout_every_n_epochs} + caption_tag_dropout_rate: {subset.caption_tag_dropout_rate} + caption_prefix: {subset.caption_prefix} + caption_suffix: {subset.caption_suffix} + color_aug: {subset.color_aug} + flip_aug: {subset.flip_aug} + face_crop_aug_range: {subset.face_crop_aug_range} + random_crop: {subset.random_crop} + token_warmup_min: {subset.token_warmup_min}, + token_warmup_step: {subset.token_warmup_step}, + alpha_mask: {subset.alpha_mask} + resize_interpolation: {subset.resize_interpolation} + custom_attributes: {subset.custom_attributes} + """), " ") + + if is_dreambooth: + info += indent(dedent(f"""\ + is_reg: {subset.is_reg} + class_tokens: {subset.class_tokens} + caption_extension: {subset.caption_extension} + \n"""), " ") + elif not is_controlnet: + info += indent(dedent(f"""\ + metadata_file: {subset.metadata_file} + \n"""), " ") + + logger.info(info) + + print_info(datasets, "Dataset") + + if len(val_datasets) > 0: + print_info(val_datasets, "Validation Dataset") # make buckets first because it determines the length of dataset # and set the same seed for all datasets seed = random.randint(0, 2**31) # actual seed is seed + epoch_no + for i, dataset in enumerate(datasets): - logger.info(f"[Dataset {i}]") + logger.info(f"[Prepare dataset {i}]") dataset.make_buckets() dataset.set_seed(seed) - return DatasetGroup(datasets) + for i, dataset in enumerate(val_datasets): + logger.info(f"[Prepare validation dataset {i}]") + dataset.make_buckets() + dataset.set_seed(seed) + + return ( + DatasetGroup(datasets), + DatasetGroup(val_datasets) if val_datasets else None + ) def generate_dreambooth_subsets_config_by_subdirs(train_data_dir: Optional[str] = None, reg_data_dir: Optional[str] = None): diff --git a/library/custom_offloading_utils.py b/library/custom_offloading_utils.py new file mode 100644 index 000000000..0681dcdcb --- /dev/null +++ b/library/custom_offloading_utils.py @@ -0,0 +1,338 @@ +from concurrent.futures import ThreadPoolExecutor +import gc +import time +from typing import Any, Optional, Union, Callable, Tuple +import torch +import torch.nn as nn + + +# Keep these functions here for portability, and private to avoid confusion with the ones in device_utils.py +def _clean_memory_on_device(device: torch.device): + r""" + Clean memory on the specified device, will be called from training scripts. + """ + gc.collect() + + # device may "cuda" or "cuda:0", so we need to check the type of device + if device.type == "cuda": + torch.cuda.empty_cache() + if device.type == "xpu": + torch.xpu.empty_cache() + if device.type == "mps": + torch.mps.empty_cache() + + +def _synchronize_device(device: torch.device): + if device.type == "cuda": + torch.cuda.synchronize() + elif device.type == "xpu": + torch.xpu.synchronize() + elif device.type == "mps": + torch.mps.synchronize() + + +def swap_weight_devices_cuda(device: torch.device, layer_to_cpu: nn.Module, layer_to_cuda: nn.Module): + assert layer_to_cpu.__class__ == layer_to_cuda.__class__ + + weight_swap_jobs: list[Tuple[nn.Module, nn.Module, torch.Tensor, torch.Tensor]] = [] + + # This is not working for all cases (e.g. SD3), so we need to find the corresponding modules + # for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()): + # print(module_to_cpu.__class__, module_to_cuda.__class__) + # if hasattr(module_to_cpu, "weight") and module_to_cpu.weight is not None: + # weight_swap_jobs.append((module_to_cpu, module_to_cuda, module_to_cpu.weight.data, module_to_cuda.weight.data)) + + modules_to_cpu = {k: v for k, v in layer_to_cpu.named_modules()} + for module_to_cuda_name, module_to_cuda in layer_to_cuda.named_modules(): + if hasattr(module_to_cuda, "weight") and module_to_cuda.weight is not None: + module_to_cpu = modules_to_cpu.get(module_to_cuda_name, None) + if module_to_cpu is not None and module_to_cpu.weight.shape == module_to_cuda.weight.shape: + weight_swap_jobs.append((module_to_cpu, module_to_cuda, module_to_cpu.weight.data, module_to_cuda.weight.data)) + else: + if module_to_cuda.weight.data.device.type != device.type: + # print( + # f"Module {module_to_cuda_name} not found in CPU model or shape mismatch, so not swapping and moving to device" + # ) + module_to_cuda.weight.data = module_to_cuda.weight.data.to(device) + + torch.cuda.current_stream().synchronize() # this prevents the illegal loss value + + stream = torch.Stream(device="cuda") + with torch.cuda.stream(stream): + # cuda to cpu + for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs: + cuda_data_view.record_stream(stream) + module_to_cpu.weight.data = cuda_data_view.data.to("cpu", non_blocking=True) + + stream.synchronize() + + # cpu to cuda + for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs: + cuda_data_view.copy_(module_to_cuda.weight.data, non_blocking=True) + module_to_cuda.weight.data = cuda_data_view + + stream.synchronize() + torch.cuda.current_stream().synchronize() # this prevents the illegal loss value + + +def swap_weight_devices_no_cuda(device: torch.device, layer_to_cpu: nn.Module, layer_to_cuda: nn.Module): + """ + not tested + """ + assert layer_to_cpu.__class__ == layer_to_cuda.__class__ + + weight_swap_jobs: list[Tuple[nn.Module, nn.Module, torch.Tensor, torch.Tensor]] = [] + for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()): + if hasattr(module_to_cpu, "weight") and module_to_cpu.weight is not None: + weight_swap_jobs.append((module_to_cpu, module_to_cuda, module_to_cpu.weight.data, module_to_cuda.weight.data)) + + # device to cpu + for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs: + module_to_cpu.weight.data = cuda_data_view.data.to("cpu", non_blocking=True) + + _synchronize_device(device) + + # cpu to device + for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs: + cuda_data_view.copy_(module_to_cuda.weight.data, non_blocking=True) + module_to_cuda.weight.data = cuda_data_view + + _synchronize_device(device) + + +def weighs_to_device(layer: nn.Module, device: torch.device): + for module in layer.modules(): + if hasattr(module, "weight") and module.weight is not None: + module.weight.data = module.weight.data.to(device, non_blocking=True) + + +class Offloader: + """ + common offloading class + """ + + def __init__(self, num_blocks: int, blocks_to_swap: int, device: torch.device, debug: bool = False): + self.num_blocks = num_blocks + self.blocks_to_swap = blocks_to_swap + self.device = device + self.debug = debug + + self.thread_pool = ThreadPoolExecutor(max_workers=1) + self.futures = {} + self.cuda_available = device.type == "cuda" + + def swap_weight_devices(self, block_to_cpu: nn.Module, block_to_cuda: nn.Module): + if self.cuda_available: + swap_weight_devices_cuda(self.device, block_to_cpu, block_to_cuda) + else: + swap_weight_devices_no_cuda(self.device, block_to_cpu, block_to_cuda) + + def _submit_move_blocks(self, blocks, block_idx_to_cpu, block_idx_to_cuda): + def move_blocks(bidx_to_cpu, block_to_cpu, bidx_to_cuda, block_to_cuda): + if self.debug: + start_time = time.perf_counter() + print(f"Move block {bidx_to_cpu} to CPU and block {bidx_to_cuda} to {'CUDA' if self.cuda_available else 'device'}") + + self.swap_weight_devices(block_to_cpu, block_to_cuda) + + if self.debug: + print(f"Moved blocks {bidx_to_cpu} and {bidx_to_cuda} in {time.perf_counter() - start_time:.2f}s") + return bidx_to_cpu, bidx_to_cuda # , event + + block_to_cpu = blocks[block_idx_to_cpu] + block_to_cuda = blocks[block_idx_to_cuda] + + self.futures[block_idx_to_cuda] = self.thread_pool.submit( + move_blocks, block_idx_to_cpu, block_to_cpu, block_idx_to_cuda, block_to_cuda + ) + + def _wait_blocks_move(self, block_idx): + if block_idx not in self.futures: + return + + if self.debug: + print(f"Wait for block {block_idx}") + start_time = time.perf_counter() + + future = self.futures.pop(block_idx) + _, bidx_to_cuda = future.result() + + assert block_idx == bidx_to_cuda, f"Block index mismatch: {block_idx} != {bidx_to_cuda}" + + if self.debug: + print(f"Waited for block {block_idx}: {time.perf_counter() - start_time:.2f}s") + + +# Gradient tensors +_grad_t = Union[tuple[torch.Tensor, ...], torch.Tensor] + + +class ModelOffloader(Offloader): + """ + supports forward offloading + """ + + def __init__( + self, + blocks: Union[list[nn.Module], nn.ModuleList], + blocks_to_swap: int, + device: torch.device, + supports_backward: bool = True, + debug: bool = False, + ): + super().__init__(len(blocks), blocks_to_swap, device, debug) + + self.supports_backward = supports_backward + self.forward_only = not supports_backward # forward only offloading: can be changed to True for inference + + if self.supports_backward: + # register backward hooks + self.remove_handles = [] + for i, block in enumerate(blocks): + hook = self.create_backward_hook(blocks, i) + if hook is not None: + handle = block.register_full_backward_hook(hook) + self.remove_handles.append(handle) + + def set_forward_only(self, forward_only: bool): + self.forward_only = forward_only + + def __del__(self): + if self.supports_backward: + for handle in self.remove_handles: + handle.remove() + + def create_backward_hook( + self, blocks: Union[list[nn.Module], nn.ModuleList], block_index: int + ) -> Optional[Callable[[nn.Module, _grad_t, _grad_t], Union[None, _grad_t]]]: + # -1 for 0-based index + num_blocks_propagated = self.num_blocks - block_index - 1 + swapping = num_blocks_propagated > 0 and num_blocks_propagated <= self.blocks_to_swap + waiting = block_index > 0 and block_index <= self.blocks_to_swap + + if not swapping and not waiting: + return None + + # create hook + block_idx_to_cpu = self.num_blocks - num_blocks_propagated + block_idx_to_cuda = self.blocks_to_swap - num_blocks_propagated + block_idx_to_wait = block_index - 1 + + def backward_hook(module: nn.Module, grad_input: _grad_t, grad_output: _grad_t): + if self.debug: + print(f"Backward hook for block {block_index}") + + if swapping: + self._submit_move_blocks(blocks, block_idx_to_cpu, block_idx_to_cuda) + if waiting: + self._wait_blocks_move(block_idx_to_wait) + return None + + return backward_hook + + def prepare_block_devices_before_forward(self, blocks: Union[list[nn.Module], nn.ModuleList]): + if self.blocks_to_swap is None or self.blocks_to_swap == 0: + return + + if self.debug: + print(f"Prepare block devices before forward") + + for b in blocks[0 : self.num_blocks - self.blocks_to_swap]: + b.to(self.device) + weighs_to_device(b, self.device) # make sure weights are on device + + for b in blocks[self.num_blocks - self.blocks_to_swap :]: + b.to(self.device) # move block to device first. this makes sure that buffers (non weights) are on the device + weighs_to_device(b, torch.device("cpu")) # make sure weights are on cpu + + _synchronize_device(self.device) + _clean_memory_on_device(self.device) + + def wait_for_block(self, block_idx: int): + if self.blocks_to_swap is None or self.blocks_to_swap == 0: + return + self._wait_blocks_move(block_idx) + + def submit_move_blocks(self, blocks: Union[list[nn.Module], nn.ModuleList], block_idx: int): + # check if blocks_to_swap is enabled + if self.blocks_to_swap is None or self.blocks_to_swap == 0: + return + + # if backward is enabled, we do not swap blocks in forward pass more than blocks_to_swap, because it should be on GPU + if not self.forward_only and block_idx >= self.blocks_to_swap: + return + + block_idx_to_cpu = block_idx + block_idx_to_cuda = self.num_blocks - self.blocks_to_swap + block_idx + # this works for forward-only offloading. move upstream blocks to cuda + block_idx_to_cuda = block_idx_to_cuda % self.num_blocks + self._submit_move_blocks(blocks, block_idx_to_cpu, block_idx_to_cuda) + + +# endregion + +# region cpu offload utils + + +def to_device(x: Any, device: torch.device) -> Any: + if isinstance(x, torch.Tensor): + return x.to(device) + elif isinstance(x, list): + return [to_device(elem, device) for elem in x] + elif isinstance(x, tuple): + return tuple(to_device(elem, device) for elem in x) + elif isinstance(x, dict): + return {k: to_device(v, device) for k, v in x.items()} + else: + return x + + +def to_cpu(x: Any) -> Any: + """ + Recursively moves torch.Tensor objects (and containers thereof) to CPU. + + Args: + x: A torch.Tensor, or a (possibly nested) list, tuple, or dict containing tensors. + + Returns: + The same structure as x, with all torch.Tensor objects moved to CPU. + Non-tensor objects are returned unchanged. + """ + if isinstance(x, torch.Tensor): + return x.cpu() + elif isinstance(x, list): + return [to_cpu(elem) for elem in x] + elif isinstance(x, tuple): + return tuple(to_cpu(elem) for elem in x) + elif isinstance(x, dict): + return {k: to_cpu(v) for k, v in x.items()} + else: + return x + + +def create_cpu_offloading_wrapper(func: Callable, device: torch.device) -> Callable: + """ + Create a wrapper function that offloads inputs to CPU before calling the original function + and moves outputs back to the specified device. + + Args: + func: The original function to wrap. + device: The device to move outputs back to. + + Returns: + A wrapped function that offloads inputs to CPU and moves outputs back to the specified device. + """ + + def wrapper(orig_func: Callable) -> Callable: + def custom_forward(*inputs): + nonlocal device, orig_func + cuda_inputs = to_device(inputs, device) + outputs = orig_func(*cuda_inputs) + return to_cpu(outputs) + + return custom_forward + + return wrapper(func) + + +# endregion diff --git a/library/custom_train_functions.py b/library/custom_train_functions.py index faf443048..ad3e69ffb 100644 --- a/library/custom_train_functions.py +++ b/library/custom_train_functions.py @@ -1,7 +1,9 @@ +from diffusers.schedulers.scheduling_ddpm import DDPMScheduler import torch import argparse import random import re +from torch.types import Number from typing import List, Optional, Union from .utils import setup_logging @@ -63,7 +65,7 @@ def enforce_zero_terminal_snr(betas): noise_scheduler.alphas_cumprod = alphas_cumprod -def apply_snr_weight(loss, timesteps, noise_scheduler, gamma, v_prediction=False): +def apply_snr_weight(loss: torch.Tensor, timesteps: torch.IntTensor, noise_scheduler: DDPMScheduler, gamma: Number, v_prediction=False): snr = torch.stack([noise_scheduler.all_snr[t] for t in timesteps]) min_snr_gamma = torch.minimum(snr, torch.full_like(snr, gamma)) if v_prediction: @@ -74,13 +76,13 @@ def apply_snr_weight(loss, timesteps, noise_scheduler, gamma, v_prediction=False return loss -def scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler): +def scale_v_prediction_loss_like_noise_prediction(loss: torch.Tensor, timesteps: torch.IntTensor, noise_scheduler: DDPMScheduler): scale = get_snr_scale(timesteps, noise_scheduler) loss = loss * scale return loss -def get_snr_scale(timesteps, noise_scheduler): +def get_snr_scale(timesteps: torch.IntTensor, noise_scheduler: DDPMScheduler): snr_t = torch.stack([noise_scheduler.all_snr[t] for t in timesteps]) # batch_size snr_t = torch.minimum(snr_t, torch.ones_like(snr_t) * 1000) # if timestep is 0, snr_t is inf, so limit it to 1000 scale = snr_t / (snr_t + 1) @@ -89,14 +91,14 @@ def get_snr_scale(timesteps, noise_scheduler): return scale -def add_v_prediction_like_loss(loss, timesteps, noise_scheduler, v_pred_like_loss): +def add_v_prediction_like_loss(loss: torch.Tensor, timesteps: torch.IntTensor, noise_scheduler: DDPMScheduler, v_pred_like_loss: torch.Tensor): scale = get_snr_scale(timesteps, noise_scheduler) # logger.info(f"add v-prediction like loss: {v_pred_like_loss}, scale: {scale}, loss: {loss}, time: {timesteps}") loss = loss + loss / scale * v_pred_like_loss return loss -def apply_debiased_estimation(loss, timesteps, noise_scheduler, v_prediction=False): +def apply_debiased_estimation(loss: torch.Tensor, timesteps: torch.IntTensor, noise_scheduler: DDPMScheduler, v_prediction=False): snr_t = torch.stack([noise_scheduler.all_snr[t] for t in timesteps]) # batch_size snr_t = torch.minimum(snr_t, torch.ones_like(snr_t) * 1000) # if timestep is 0, snr_t is inf, so limit it to 1000 if v_prediction: @@ -453,7 +455,7 @@ def get_weighted_text_embeddings( # https://wandb.ai/johnowhitaker/multires_noise/reports/Multi-Resolution-Noise-for-Diffusion-Model-Training--VmlldzozNjYyOTU2 -def pyramid_noise_like(noise, device, iterations=6, discount=0.4): +def pyramid_noise_like(noise, device, iterations=6, discount=0.4) -> torch.FloatTensor: b, c, w, h = noise.shape # EDIT: w and h get over-written, rename for a different variant! u = torch.nn.Upsample(size=(w, h), mode="bilinear").to(device) for i in range(iterations): @@ -466,7 +468,7 @@ def pyramid_noise_like(noise, device, iterations=6, discount=0.4): # https://www.crosslabs.org//blog/diffusion-with-offset-noise -def apply_noise_offset(latents, noise, noise_offset, adaptive_noise_scale): +def apply_noise_offset(latents, noise, noise_offset, adaptive_noise_scale) -> torch.FloatTensor: if noise_offset is None: return noise if adaptive_noise_scale is not None: @@ -482,7 +484,7 @@ def apply_noise_offset(latents, noise, noise_offset, adaptive_noise_scale): return noise -def apply_masked_loss(loss, batch): +def apply_masked_loss(loss, batch) -> torch.FloatTensor: if "conditioning_images" in batch: # conditioning image is -1 to 1. we need to convert it to 0 to 1 mask_image = batch["conditioning_images"].to(dtype=loss.dtype)[:, 0].unsqueeze(1) # use R channel diff --git a/library/deepspeed_utils.py b/library/deepspeed_utils.py index 99a7b2b3b..a8a05c3a1 100644 --- a/library/deepspeed_utils.py +++ b/library/deepspeed_utils.py @@ -5,6 +5,8 @@ from .utils import setup_logging +from .device_utils import get_preferred_device + setup_logging() import logging @@ -94,6 +96,7 @@ def prepare_deepspeed_plugin(args: argparse.Namespace): deepspeed_plugin.deepspeed_config["train_batch_size"] = ( args.train_batch_size * args.gradient_accumulation_steps * int(os.environ["WORLD_SIZE"]) ) + deepspeed_plugin.set_mixed_precision(args.mixed_precision) if args.mixed_precision.lower() == "fp16": deepspeed_plugin.deepspeed_config["fp16"]["initial_scale_power"] = 0 # preventing overflow. @@ -122,18 +125,56 @@ def prepare_deepspeed_model(args: argparse.Namespace, **models): class DeepSpeedWrapper(torch.nn.Module): def __init__(self, **kw_models) -> None: super().__init__() + self.models = torch.nn.ModuleDict() + + wrap_model_forward_with_torch_autocast = args.mixed_precision is not "no" for key, model in kw_models.items(): if isinstance(model, list): model = torch.nn.ModuleList(model) + + if wrap_model_forward_with_torch_autocast: + model = self.__wrap_model_with_torch_autocast(model) + assert isinstance( model, torch.nn.Module ), f"model must be an instance of torch.nn.Module, but got {key} is {type(model)}" + self.models.update(torch.nn.ModuleDict({key: model})) + def __wrap_model_with_torch_autocast(self, model): + if isinstance(model, torch.nn.ModuleList): + model = torch.nn.ModuleList([self.__wrap_model_forward_with_torch_autocast(m) for m in model]) + else: + model = self.__wrap_model_forward_with_torch_autocast(model) + return model + + def __wrap_model_forward_with_torch_autocast(self, model): + + assert hasattr(model, "forward"), f"model must have a forward method." + + forward_fn = model.forward + + def forward(*args, **kwargs): + try: + device_type = model.device.type + except AttributeError: + logger.warning( + "[DeepSpeed] model.device is not available. Using get_preferred_device() " + "to determine the device_type for torch.autocast()." + ) + device_type = get_preferred_device().type + + with torch.autocast(device_type = device_type): + return forward_fn(*args, **kwargs) + + model.forward = forward + return model + def get_models(self): return self.models + ds_model = DeepSpeedWrapper(**models) return ds_model diff --git a/library/device_utils.py b/library/device_utils.py index d2e197450..2d59b64be 100644 --- a/library/device_utils.py +++ b/library/device_utils.py @@ -1,7 +1,10 @@ import functools import gc +from typing import Optional, Union import torch + + try: # intel gpu support for pytorch older than 2.5 # ipex is not needed after pytorch 2.5 @@ -36,12 +39,15 @@ def clean_memory(): torch.mps.empty_cache() -def clean_memory_on_device(device: torch.device): +def clean_memory_on_device(device: Optional[Union[str, torch.device]]): r""" Clean memory on the specified device, will be called from training scripts. """ gc.collect() - + if device is None: + return + if isinstance(device, str): + device = torch.device(device) # device may "cuda" or "cuda:0", so we need to check the type of device if device.type == "cuda": torch.cuda.empty_cache() @@ -51,6 +57,19 @@ def clean_memory_on_device(device: torch.device): torch.mps.empty_cache() +def synchronize_device(device: Optional[Union[str, torch.device]]): + if device is None: + return + if isinstance(device, str): + device = torch.device(device) + if device.type == "cuda": + torch.cuda.synchronize() + elif device.type == "xpu": + torch.xpu.synchronize() + elif device.type == "mps": + torch.mps.synchronize() + + @functools.lru_cache(maxsize=None) def get_preferred_device() -> torch.device: r""" diff --git a/library/flux_models.py b/library/flux_models.py new file mode 100644 index 000000000..d2d7e06c7 --- /dev/null +++ b/library/flux_models.py @@ -0,0 +1,1329 @@ +# copy from FLUX repo: https://github.com/black-forest-labs/flux +# license: Apache-2.0 License + + +import math +import os +import time +from concurrent.futures import Future, ThreadPoolExecutor +from dataclasses import dataclass +from typing import Dict, List, Optional, Union + +from library import utils +from library.device_utils import clean_memory_on_device, init_ipex + +init_ipex() + +import torch +from einops import rearrange +from torch import Tensor, nn +from torch.utils.checkpoint import checkpoint + +from library import custom_offloading_utils + +# USE_REENTRANT = True + + +@dataclass +class FluxParams: + in_channels: int + vec_in_dim: int + context_in_dim: int + hidden_size: int + mlp_ratio: float + num_heads: int + depth: int + depth_single_blocks: int + axes_dim: list[int] + theta: int + qkv_bias: bool + guidance_embed: bool + + +# region autoencoder + + +@dataclass +class AutoEncoderParams: + resolution: int + in_channels: int + ch: int + out_ch: int + ch_mult: list[int] + num_res_blocks: int + z_channels: int + scale_factor: float + shift_factor: float + + +def swish(x: Tensor) -> Tensor: + return x * torch.sigmoid(x) + + +class AttnBlock(nn.Module): + def __init__(self, in_channels: int): + super().__init__() + self.in_channels = in_channels + + self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) + + self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1) + self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1) + self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1) + self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1) + + def attention(self, h_: Tensor) -> Tensor: + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + b, c, h, w = q.shape + q = rearrange(q, "b c h w -> b 1 (h w) c").contiguous() + k = rearrange(k, "b c h w -> b 1 (h w) c").contiguous() + v = rearrange(v, "b c h w -> b 1 (h w) c").contiguous() + h_ = nn.functional.scaled_dot_product_attention(q, k, v) + + return rearrange(h_, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b) + + def forward(self, x: Tensor) -> Tensor: + return x + self.proj_out(self.attention(x)) + + +class ResnetBlock(nn.Module): + def __init__(self, in_channels: int, out_channels: int): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + + self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) + self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) + self.norm2 = nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, affine=True) + self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) + if self.in_channels != self.out_channels: + self.nin_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) + + def forward(self, x): + h = x + h = self.norm1(h) + h = swish(h) + h = self.conv1(h) + + h = self.norm2(h) + h = swish(h) + h = self.conv2(h) + + if self.in_channels != self.out_channels: + x = self.nin_shortcut(x) + + return x + h + + +class Downsample(nn.Module): + def __init__(self, in_channels: int): + super().__init__() + # no asymmetric padding in torch conv, must do it ourselves + self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0) + + def forward(self, x: Tensor): + pad = (0, 1, 0, 1) + x = nn.functional.pad(x, pad, mode="constant", value=0) + x = self.conv(x) + return x + + +class Upsample(nn.Module): + def __init__(self, in_channels: int): + super().__init__() + self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1) + + def forward(self, x: Tensor): + x = nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") + x = self.conv(x) + return x + + +class Encoder(nn.Module): + def __init__( + self, + resolution: int, + in_channels: int, + ch: int, + ch_mult: list[int], + num_res_blocks: int, + z_channels: int, + ): + super().__init__() + self.ch = ch + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + # downsampling + self.conv_in = nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1) + + curr_res = resolution + in_ch_mult = (1,) + tuple(ch_mult) + self.in_ch_mult = in_ch_mult + self.down = nn.ModuleList() + block_in = self.ch + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = ch * in_ch_mult[i_level] + block_out = ch * ch_mult[i_level] + for _ in range(self.num_res_blocks): + block.append(ResnetBlock(in_channels=block_in, out_channels=block_out)) + block_in = block_out + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions - 1: + down.downsample = Downsample(block_in) + curr_res = curr_res // 2 + self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in) + self.mid.attn_1 = AttnBlock(block_in) + self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in) + + # end + self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True) + self.conv_out = nn.Conv2d(block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1) + + def forward(self, x: Tensor) -> Tensor: + # downsampling + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1]) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs.append(h) + if i_level != self.num_resolutions - 1: + hs.append(self.down[i_level].downsample(hs[-1])) + + # middle + h = hs[-1] + h = self.mid.block_1(h) + h = self.mid.attn_1(h) + h = self.mid.block_2(h) + # end + h = self.norm_out(h) + h = swish(h) + h = self.conv_out(h) + return h + + +class Decoder(nn.Module): + def __init__( + self, + ch: int, + out_ch: int, + ch_mult: list[int], + num_res_blocks: int, + in_channels: int, + resolution: int, + z_channels: int, + ): + super().__init__() + self.ch = ch + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + self.ffactor = 2 ** (self.num_resolutions - 1) + + # compute in_ch_mult, block_in and curr_res at lowest res + block_in = ch * ch_mult[self.num_resolutions - 1] + curr_res = resolution // 2 ** (self.num_resolutions - 1) + self.z_shape = (1, z_channels, curr_res, curr_res) + + # z to block_in + self.conv_in = nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in) + self.mid.attn_1 = AttnBlock(block_in) + self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch * ch_mult[i_level] + for _ in range(self.num_res_blocks + 1): + block.append(ResnetBlock(in_channels=block_in, out_channels=block_out)) + block_in = block_out + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + up.upsample = Upsample(block_in) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True) + self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1) + + def forward(self, z: Tensor) -> Tensor: + # z to block_in + h = self.conv_in(z) + + # middle + h = self.mid.block_1(h) + h = self.mid.attn_1(h) + h = self.mid.block_2(h) + + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + h = self.up[i_level].block[i_block](h) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + if i_level != 0: + h = self.up[i_level].upsample(h) + + # end + h = self.norm_out(h) + h = swish(h) + h = self.conv_out(h) + return h + + +class DiagonalGaussian(nn.Module): + def __init__(self, sample: bool = True, chunk_dim: int = 1): + super().__init__() + self.sample = sample + self.chunk_dim = chunk_dim + + def forward(self, z: Tensor) -> Tensor: + mean, logvar = torch.chunk(z, 2, dim=self.chunk_dim) + if self.sample: + std = torch.exp(0.5 * logvar) + return mean + std * torch.randn_like(mean) + else: + return mean + + +class AutoEncoder(nn.Module): + def __init__(self, params: AutoEncoderParams): + super().__init__() + self.encoder = Encoder( + resolution=params.resolution, + in_channels=params.in_channels, + ch=params.ch, + ch_mult=params.ch_mult, + num_res_blocks=params.num_res_blocks, + z_channels=params.z_channels, + ) + self.decoder = Decoder( + resolution=params.resolution, + in_channels=params.in_channels, + ch=params.ch, + out_ch=params.out_ch, + ch_mult=params.ch_mult, + num_res_blocks=params.num_res_blocks, + z_channels=params.z_channels, + ) + self.reg = DiagonalGaussian() + + self.scale_factor = params.scale_factor + self.shift_factor = params.shift_factor + + @property + def device(self) -> torch.device: + return next(self.parameters()).device + + @property + def dtype(self) -> torch.dtype: + return next(self.parameters()).dtype + + def encode(self, x: Tensor) -> Tensor: + z = self.reg(self.encoder(x)) + z = self.scale_factor * (z - self.shift_factor) + return z + + def decode(self, z: Tensor) -> Tensor: + z = z / self.scale_factor + self.shift_factor + return self.decoder(z) + + def forward(self, x: Tensor) -> Tensor: + return self.decode(self.encode(x)) + + +# endregion +# region config + + +@dataclass +class ModelSpec: + params: FluxParams + ae_params: AutoEncoderParams + ckpt_path: str | None + ae_path: str | None + # repo_id: str | None + # repo_flow: str | None + # repo_ae: str | None + + +configs = { + "dev": ModelSpec( + # repo_id="black-forest-labs/FLUX.1-dev", + # repo_flow="flux1-dev.sft", + # repo_ae="ae.sft", + ckpt_path=None, # os.getenv("FLUX_DEV"), + params=FluxParams( + in_channels=64, + vec_in_dim=768, + context_in_dim=4096, + hidden_size=3072, + mlp_ratio=4.0, + num_heads=24, + depth=19, + depth_single_blocks=38, + axes_dim=[16, 56, 56], + theta=10_000, + qkv_bias=True, + guidance_embed=True, + ), + ae_path=None, # os.getenv("AE"), + ae_params=AutoEncoderParams( + resolution=256, + in_channels=3, + ch=128, + out_ch=3, + ch_mult=[1, 2, 4, 4], + num_res_blocks=2, + z_channels=16, + scale_factor=0.3611, + shift_factor=0.1159, + ), + ), + "schnell": ModelSpec( + # repo_id="black-forest-labs/FLUX.1-schnell", + # repo_flow="flux1-schnell.sft", + # repo_ae="ae.sft", + ckpt_path=None, # os.getenv("FLUX_SCHNELL"), + params=FluxParams( + in_channels=64, + vec_in_dim=768, + context_in_dim=4096, + hidden_size=3072, + mlp_ratio=4.0, + num_heads=24, + depth=19, + depth_single_blocks=38, + axes_dim=[16, 56, 56], + theta=10_000, + qkv_bias=True, + guidance_embed=False, + ), + ae_path=None, # os.getenv("AE"), + ae_params=AutoEncoderParams( + resolution=256, + in_channels=3, + ch=128, + out_ch=3, + ch_mult=[1, 2, 4, 4], + num_res_blocks=2, + z_channels=16, + scale_factor=0.3611, + shift_factor=0.1159, + ), + ), +} + + +# endregion + +# region math + + +def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, attn_mask: Optional[Tensor] = None) -> Tensor: + q, k = apply_rope(q, k, pe) + + x = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask) + x = rearrange(x, "B H L D -> B L (H D)") + + return x + + +def rope(pos: Tensor, dim: int, theta: int) -> Tensor: + assert dim % 2 == 0 + scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim + omega = 1.0 / (theta**scale) + out = torch.einsum("...n,d->...nd", pos, omega) + out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1) + out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2) + return out.float() + + +def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]: + xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2) + xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2) + xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1] + xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1] + return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk) + + +# endregion + + +# region layers + + +# for cpu_offload_checkpointing + + +def to_cuda(x): + if isinstance(x, torch.Tensor): + return x.cuda() + elif isinstance(x, (list, tuple)): + return [to_cuda(elem) for elem in x] + elif isinstance(x, dict): + return {k: to_cuda(v) for k, v in x.items()} + else: + return x + + +def to_cpu(x): + if isinstance(x, torch.Tensor): + return x.cpu() + elif isinstance(x, (list, tuple)): + return [to_cpu(elem) for elem in x] + elif isinstance(x, dict): + return {k: to_cpu(v) for k, v in x.items()} + else: + return x + + +class EmbedND(nn.Module): + def __init__(self, dim: int, theta: int, axes_dim: list[int]): + super().__init__() + self.dim = dim + self.theta = theta + self.axes_dim = axes_dim + + def forward(self, ids: Tensor) -> Tensor: + n_axes = ids.shape[-1] + emb = torch.cat( + [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)], + dim=-3, + ) + + return emb.unsqueeze(1) + + +def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 1000.0): + """ + Create sinusoidal timestep embeddings. + :param t: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an (N, D) Tensor of positional embeddings. + """ + t = time_factor * t + half = dim // 2 + freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(t.device) + + args = t[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + if torch.is_floating_point(t): + embedding = embedding.to(t) + return embedding + + +class MLPEmbedder(nn.Module): + def __init__(self, in_dim: int, hidden_dim: int): + super().__init__() + self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True) + self.silu = nn.SiLU() + self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True) + + self.gradient_checkpointing = False + + def enable_gradient_checkpointing(self): + self.gradient_checkpointing = True + + def disable_gradient_checkpointing(self): + self.gradient_checkpointing = False + + def _forward(self, x: Tensor) -> Tensor: + return self.out_layer(self.silu(self.in_layer(x))) + + def forward(self, *args, **kwargs): + if self.training and self.gradient_checkpointing: + return checkpoint(self._forward, *args, use_reentrant=False, **kwargs) + else: + return self._forward(*args, **kwargs) + + # def forward(self, x): + # if self.training and self.gradient_checkpointing: + # def create_custom_forward(func): + # def custom_forward(*inputs): + # return func(*inputs) + # return custom_forward + # return torch.utils.checkpoint.checkpoint(create_custom_forward(self._forward), x, use_reentrant=USE_REENTRANT) + # else: + # return self._forward(x) + + +class RMSNorm(torch.nn.Module): + def __init__(self, dim: int): + super().__init__() + self.scale = nn.Parameter(torch.ones(dim)) + + def forward(self, x: Tensor): + x_dtype = x.dtype + x = x.float() + rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6) + # return (x * rrms).to(dtype=x_dtype) * self.scale + return ((x * rrms) * self.scale.float()).to(dtype=x_dtype) + + +class QKNorm(torch.nn.Module): + def __init__(self, dim: int): + super().__init__() + self.query_norm = RMSNorm(dim) + self.key_norm = RMSNorm(dim) + + def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple[Tensor, Tensor]: + q = self.query_norm(q) + k = self.key_norm(k) + return q.to(v), k.to(v) + + +class SelfAttention(nn.Module): + def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.norm = QKNorm(head_dim) + self.proj = nn.Linear(dim, dim) + + # this is not called from DoubleStreamBlock/SingleStreamBlock because they uses attention function directly + def forward(self, x: Tensor, pe: Tensor) -> Tensor: + qkv = self.qkv(x) + q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads) + q, k = self.norm(q, k, v) + x = attention(q, k, v, pe=pe) + x = self.proj(x) + return x + + +@dataclass +class ModulationOut: + shift: Tensor + scale: Tensor + gate: Tensor + + +class Modulation(nn.Module): + def __init__(self, dim: int, double: bool): + super().__init__() + self.is_double = double + self.multiplier = 6 if double else 3 + self.lin = nn.Linear(dim, self.multiplier * dim, bias=True) + + def forward(self, vec: Tensor) -> tuple[ModulationOut, ModulationOut | None]: + out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(self.multiplier, dim=-1) + + return ( + ModulationOut(*out[:3]), + ModulationOut(*out[3:]) if self.is_double else None, + ) + + +class DoubleStreamBlock(nn.Module): + def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False): + super().__init__() + + mlp_hidden_dim = int(hidden_size * mlp_ratio) + self.num_heads = num_heads + self.hidden_size = hidden_size + self.img_mod = Modulation(hidden_size, double=True) + self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias) + + self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.img_mlp = nn.Sequential( + nn.Linear(hidden_size, mlp_hidden_dim, bias=True), + nn.GELU(approximate="tanh"), + nn.Linear(mlp_hidden_dim, hidden_size, bias=True), + ) + + self.txt_mod = Modulation(hidden_size, double=True) + self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias) + + self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.txt_mlp = nn.Sequential( + nn.Linear(hidden_size, mlp_hidden_dim, bias=True), + nn.GELU(approximate="tanh"), + nn.Linear(mlp_hidden_dim, hidden_size, bias=True), + ) + + self.gradient_checkpointing = False + self.cpu_offload_checkpointing = False + + def enable_gradient_checkpointing(self, cpu_offload: bool = False): + self.gradient_checkpointing = True + self.cpu_offload_checkpointing = cpu_offload + + def disable_gradient_checkpointing(self): + self.gradient_checkpointing = False + self.cpu_offload_checkpointing = False + + def _forward( + self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, txt_attention_mask: Optional[Tensor] = None + ) -> tuple[Tensor, Tensor]: + img_mod1, img_mod2 = self.img_mod(vec) + txt_mod1, txt_mod2 = self.txt_mod(vec) + + # prepare image for attention + img_modulated = self.img_norm1(img) + img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift + img_qkv = self.img_attn.qkv(img_modulated) + img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads) + img_q, img_k = self.img_attn.norm(img_q, img_k, img_v) + + # prepare txt for attention + txt_modulated = self.txt_norm1(txt) + txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift + txt_qkv = self.txt_attn.qkv(txt_modulated) + txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads) + txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v) + + # run actual attention + q = torch.cat((txt_q, img_q), dim=2) + k = torch.cat((txt_k, img_k), dim=2) + v = torch.cat((txt_v, img_v), dim=2) + + # make attention mask if not None + attn_mask = None + if txt_attention_mask is not None: + # F.scaled_dot_product_attention expects attn_mask to be bool for binary mask + attn_mask = txt_attention_mask.to(torch.bool) # b, seq_len + attn_mask = torch.cat( + (attn_mask, torch.ones(attn_mask.shape[0], img.shape[1], device=attn_mask.device, dtype=torch.bool)), dim=1 + ) # b, seq_len + img_len + + # broadcast attn_mask to all heads + attn_mask = attn_mask[:, None, None, :].expand(-1, q.shape[1], q.shape[2], -1) + + attn = attention(q, k, v, pe=pe, attn_mask=attn_mask) + txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :] + + # calculate the img blocks + img = img + img_mod1.gate * self.img_attn.proj(img_attn) + img = img + img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift) + + # calculate the txt blocks + txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn) + txt = txt + txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift) + return img, txt + + def forward( + self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, txt_attention_mask: Optional[Tensor] = None + ) -> tuple[Tensor, Tensor]: + if self.training and self.gradient_checkpointing: + if not self.cpu_offload_checkpointing: + return checkpoint(self._forward, img, txt, vec, pe, txt_attention_mask, use_reentrant=False) + # cpu offload checkpointing + + def create_custom_forward(func): + def custom_forward(*inputs): + cuda_inputs = to_cuda(inputs) + outputs = func(*cuda_inputs) + return to_cpu(outputs) + + return custom_forward + + return torch.utils.checkpoint.checkpoint( + create_custom_forward(self._forward), img, txt, vec, pe, txt_attention_mask, use_reentrant=False + ) + + else: + return self._forward(img, txt, vec, pe, txt_attention_mask) + + +class SingleStreamBlock(nn.Module): + """ + A DiT block with parallel linear layers as described in + https://arxiv.org/abs/2302.05442 and adapted modulation interface. + """ + + def __init__( + self, + hidden_size: int, + num_heads: int, + mlp_ratio: float = 4.0, + qk_scale: float | None = None, + ): + super().__init__() + self.hidden_dim = hidden_size + self.num_heads = num_heads + head_dim = hidden_size // num_heads + self.scale = qk_scale or head_dim**-0.5 + + self.mlp_hidden_dim = int(hidden_size * mlp_ratio) + # qkv and mlp_in + self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim) + # proj and mlp_out + self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size) + + self.norm = QKNorm(head_dim) + + self.hidden_size = hidden_size + self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + + self.mlp_act = nn.GELU(approximate="tanh") + self.modulation = Modulation(hidden_size, double=False) + + self.gradient_checkpointing = False + self.cpu_offload_checkpointing = False + + def enable_gradient_checkpointing(self, cpu_offload: bool = False): + self.gradient_checkpointing = True + self.cpu_offload_checkpointing = cpu_offload + + def disable_gradient_checkpointing(self): + self.gradient_checkpointing = False + self.cpu_offload_checkpointing = False + + def _forward(self, x: Tensor, vec: Tensor, pe: Tensor, txt_attention_mask: Optional[Tensor] = None) -> Tensor: + mod, _ = self.modulation(vec) + x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift + qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1) + + q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads) + q, k = self.norm(q, k, v) + + # make attention mask if not None + attn_mask = None + if txt_attention_mask is not None: + # F.scaled_dot_product_attention expects attn_mask to be bool for binary mask + attn_mask = txt_attention_mask.to(torch.bool) # b, seq_len + attn_mask = torch.cat( + ( + attn_mask, + torch.ones( + attn_mask.shape[0], x.shape[1] - txt_attention_mask.shape[1], device=attn_mask.device, dtype=torch.bool + ), + ), + dim=1, + ) # b, seq_len + img_len = x_len + + # broadcast attn_mask to all heads + attn_mask = attn_mask[:, None, None, :].expand(-1, q.shape[1], q.shape[2], -1) + + # compute attention + attn = attention(q, k, v, pe=pe, attn_mask=attn_mask) + + # compute activation in mlp stream, cat again and run second linear layer + output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2)) + return x + mod.gate * output + + def forward(self, x: Tensor, vec: Tensor, pe: Tensor, txt_attention_mask: Optional[Tensor] = None) -> Tensor: + if self.training and self.gradient_checkpointing: + if not self.cpu_offload_checkpointing: + return checkpoint(self._forward, x, vec, pe, txt_attention_mask, use_reentrant=False) + + # cpu offload checkpointing + + def create_custom_forward(func): + def custom_forward(*inputs): + cuda_inputs = to_cuda(inputs) + outputs = func(*cuda_inputs) + return to_cpu(outputs) + + return custom_forward + + return torch.utils.checkpoint.checkpoint( + create_custom_forward(self._forward), x, vec, pe, txt_attention_mask, use_reentrant=False + ) + else: + return self._forward(x, vec, pe, txt_attention_mask) + + +class LastLayer(nn.Module): + def __init__(self, hidden_size: int, patch_size: int, out_channels: int): + super().__init__() + self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True) + self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True)) + + def forward(self, x: Tensor, vec: Tensor) -> Tensor: + shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1) + x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :] + x = self.linear(x) + return x + + +# endregion + + +class Flux(nn.Module): + """ + Transformer model for flow matching on sequences. + """ + + def __init__(self, params: FluxParams): + super().__init__() + + self.params = params + self.in_channels = params.in_channels + self.out_channels = self.in_channels + if params.hidden_size % params.num_heads != 0: + raise ValueError(f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}") + pe_dim = params.hidden_size // params.num_heads + if sum(params.axes_dim) != pe_dim: + raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}") + self.hidden_size = params.hidden_size + self.num_heads = params.num_heads + self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim) + self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True) + self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) + self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size) + self.guidance_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) if params.guidance_embed else nn.Identity() + self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size) + + self.double_blocks = nn.ModuleList( + [ + DoubleStreamBlock( + self.hidden_size, + self.num_heads, + mlp_ratio=params.mlp_ratio, + qkv_bias=params.qkv_bias, + ) + for _ in range(params.depth) + ] + ) + + self.single_blocks = nn.ModuleList( + [ + SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio) + for _ in range(params.depth_single_blocks) + ] + ) + + self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels) + + self.gradient_checkpointing = False + self.cpu_offload_checkpointing = False + self.blocks_to_swap = None + + self.offloader_double = None + self.offloader_single = None + self.num_double_blocks = len(self.double_blocks) + self.num_single_blocks = len(self.single_blocks) + + def get_model_type(self) -> str: + return "flux" + + @property + def device(self): + return next(self.parameters()).device + + @property + def dtype(self): + return next(self.parameters()).dtype + + def enable_gradient_checkpointing(self, cpu_offload: bool = False): + self.gradient_checkpointing = True + self.cpu_offload_checkpointing = cpu_offload + + self.time_in.enable_gradient_checkpointing() + self.vector_in.enable_gradient_checkpointing() + if self.guidance_in.__class__ != nn.Identity: + self.guidance_in.enable_gradient_checkpointing() + + for block in self.double_blocks + self.single_blocks: + block.enable_gradient_checkpointing(cpu_offload=cpu_offload) + + print(f"FLUX: Gradient checkpointing enabled. CPU offload: {cpu_offload}") + + def disable_gradient_checkpointing(self): + self.gradient_checkpointing = False + self.cpu_offload_checkpointing = False + + self.time_in.disable_gradient_checkpointing() + self.vector_in.disable_gradient_checkpointing() + if self.guidance_in.__class__ != nn.Identity: + self.guidance_in.disable_gradient_checkpointing() + + for block in self.double_blocks + self.single_blocks: + block.disable_gradient_checkpointing() + + print("FLUX: Gradient checkpointing disabled.") + + def enable_block_swap(self, num_blocks: int, device: torch.device): + self.blocks_to_swap = num_blocks + double_blocks_to_swap = num_blocks // 2 + single_blocks_to_swap = (num_blocks - double_blocks_to_swap) * 2 + + assert double_blocks_to_swap <= self.num_double_blocks - 2 and single_blocks_to_swap <= self.num_single_blocks - 2, ( + f"Cannot swap more than {self.num_double_blocks - 2} double blocks and {self.num_single_blocks - 2} single blocks. " + f"Requested {double_blocks_to_swap} double blocks and {single_blocks_to_swap} single blocks." + ) + + self.offloader_double = custom_offloading_utils.ModelOffloader( + self.double_blocks, double_blocks_to_swap, device # , debug=True + ) + self.offloader_single = custom_offloading_utils.ModelOffloader( + self.single_blocks, single_blocks_to_swap, device # , debug=True + ) + print( + f"FLUX: Block swap enabled. Swapping {num_blocks} blocks, double blocks: {double_blocks_to_swap}, single blocks: {single_blocks_to_swap}." + ) + + def move_to_device_except_swap_blocks(self, device: torch.device): + # assume model is on cpu. do not move blocks to device to reduce temporary memory usage + if self.blocks_to_swap: + save_double_blocks = self.double_blocks + save_single_blocks = self.single_blocks + self.double_blocks = None + self.single_blocks = None + + self.to(device) + + if self.blocks_to_swap: + self.double_blocks = save_double_blocks + self.single_blocks = save_single_blocks + + def prepare_block_swap_before_forward(self): + if self.blocks_to_swap is None or self.blocks_to_swap == 0: + return + self.offloader_double.prepare_block_devices_before_forward(self.double_blocks) + self.offloader_single.prepare_block_devices_before_forward(self.single_blocks) + + def get_mod_vectors(self, timesteps: Tensor, guidance: Tensor | None = None, batch_size: int | None = None) -> Tensor: + return None # FLUX.1 does not use mod_vectors, but Chroma does. + + def forward( + self, + img: Tensor, + img_ids: Tensor, + txt: Tensor, + txt_ids: Tensor, + timesteps: Tensor, + y: Tensor, + block_controlnet_hidden_states=None, + block_controlnet_single_hidden_states=None, + guidance: Tensor | None = None, + txt_attention_mask: Tensor | None = None, + mod_vectors: Tensor | None = None, + ) -> Tensor: + if img.ndim != 3 or txt.ndim != 3: + raise ValueError("Input img and txt tensors must have 3 dimensions.") + + # running on sequences img + img = self.img_in(img) + vec = self.time_in(timestep_embedding(timesteps, 256)) + if self.params.guidance_embed: + if guidance is None: + raise ValueError("Didn't get guidance strength for guidance distilled model.") + vec = vec + self.guidance_in(timestep_embedding(guidance, 256)) + vec = vec + self.vector_in(y) + txt = self.txt_in(txt) + + ids = torch.cat((txt_ids, img_ids), dim=1) + pe = self.pe_embedder(ids) + if block_controlnet_hidden_states is not None: + controlnet_depth = len(block_controlnet_hidden_states) + if block_controlnet_single_hidden_states is not None: + controlnet_single_depth = len(block_controlnet_single_hidden_states) + + if not self.blocks_to_swap: + for block_idx, block in enumerate(self.double_blocks): + img, txt = block(img=img, txt=txt, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask) + if block_controlnet_hidden_states is not None and controlnet_depth > 0: + img = img + block_controlnet_hidden_states[block_idx % controlnet_depth] + + img = torch.cat((txt, img), 1) + for block_idx, block in enumerate(self.single_blocks): + img = block(img, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask) + if block_controlnet_single_hidden_states is not None and controlnet_single_depth > 0: + img = img + block_controlnet_single_hidden_states[block_idx % controlnet_single_depth] + else: + for block_idx, block in enumerate(self.double_blocks): + self.offloader_double.wait_for_block(block_idx) + + img, txt = block(img=img, txt=txt, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask) + if block_controlnet_hidden_states is not None and controlnet_depth > 0: + img = img + block_controlnet_hidden_states[block_idx % controlnet_depth] + + self.offloader_double.submit_move_blocks(self.double_blocks, block_idx) + + img = torch.cat((txt, img), 1) + + for block_idx, block in enumerate(self.single_blocks): + self.offloader_single.wait_for_block(block_idx) + + img = block(img, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask) + if block_controlnet_single_hidden_states is not None and controlnet_single_depth > 0: + img = img + block_controlnet_single_hidden_states[block_idx % controlnet_single_depth] + + self.offloader_single.submit_move_blocks(self.single_blocks, block_idx) + + img = img[:, txt.shape[1] :, ...] + + if self.training and self.cpu_offload_checkpointing: + img = img.to(self.device) + vec = vec.to(self.device) + + img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels) + + return img + + +def zero_module(module): + for p in module.parameters(): + nn.init.zeros_(p) + return module + + +class ControlNetFlux(nn.Module): + """ + Transformer model for flow matching on sequences. + """ + + def __init__(self, params: FluxParams, controlnet_depth=2, controlnet_single_depth=0): + super().__init__() + + self.params = params + self.in_channels = params.in_channels + self.out_channels = self.in_channels + if params.hidden_size % params.num_heads != 0: + raise ValueError(f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}") + pe_dim = params.hidden_size // params.num_heads + if sum(params.axes_dim) != pe_dim: + raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}") + self.hidden_size = params.hidden_size + self.num_heads = params.num_heads + self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim) + self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True) + self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) + self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size) + self.guidance_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) if params.guidance_embed else nn.Identity() + self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size) + + self.double_blocks = nn.ModuleList( + [ + DoubleStreamBlock( + self.hidden_size, + self.num_heads, + mlp_ratio=params.mlp_ratio, + qkv_bias=params.qkv_bias, + ) + for _ in range(controlnet_depth) + ] + ) + + self.single_blocks = nn.ModuleList( + [ + SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio) + for _ in range(controlnet_single_depth) + ] + ) + + self.gradient_checkpointing = False + self.cpu_offload_checkpointing = False + self.blocks_to_swap = None + + self.offloader_double = None + self.offloader_single = None + self.num_double_blocks = len(self.double_blocks) + self.num_single_blocks = len(self.single_blocks) + + # add ControlNet blocks + self.controlnet_blocks = nn.ModuleList([]) + for _ in range(controlnet_depth): + controlnet_block = nn.Linear(self.hidden_size, self.hidden_size) + controlnet_block = zero_module(controlnet_block) + self.controlnet_blocks.append(controlnet_block) + self.controlnet_blocks_for_single = nn.ModuleList([]) + for _ in range(controlnet_single_depth): + controlnet_block = nn.Linear(self.hidden_size, self.hidden_size) + controlnet_block = zero_module(controlnet_block) + self.controlnet_blocks_for_single.append(controlnet_block) + self.pos_embed_input = nn.Linear(self.in_channels, self.hidden_size, bias=True) + self.gradient_checkpointing = False + self.input_hint_block = nn.Sequential( + nn.Conv2d(3, 16, 3, padding=1), + nn.SiLU(), + nn.Conv2d(16, 16, 3, padding=1), + nn.SiLU(), + nn.Conv2d(16, 16, 3, padding=1, stride=2), + nn.SiLU(), + nn.Conv2d(16, 16, 3, padding=1), + nn.SiLU(), + nn.Conv2d(16, 16, 3, padding=1, stride=2), + nn.SiLU(), + nn.Conv2d(16, 16, 3, padding=1), + nn.SiLU(), + nn.Conv2d(16, 16, 3, padding=1, stride=2), + nn.SiLU(), + zero_module(nn.Conv2d(16, 16, 3, padding=1)), + ) + + @property + def device(self): + return next(self.parameters()).device + + @property + def dtype(self): + return next(self.parameters()).dtype + + def enable_gradient_checkpointing(self, cpu_offload: bool = False): + self.gradient_checkpointing = True + self.cpu_offload_checkpointing = cpu_offload + + self.time_in.enable_gradient_checkpointing() + self.vector_in.enable_gradient_checkpointing() + if self.guidance_in.__class__ != nn.Identity: + self.guidance_in.enable_gradient_checkpointing() + + for block in self.double_blocks + self.single_blocks: + block.enable_gradient_checkpointing(cpu_offload=cpu_offload) + + print(f"FLUX: Gradient checkpointing enabled. CPU offload: {cpu_offload}") + + def disable_gradient_checkpointing(self): + self.gradient_checkpointing = False + self.cpu_offload_checkpointing = False + + self.time_in.disable_gradient_checkpointing() + self.vector_in.disable_gradient_checkpointing() + if self.guidance_in.__class__ != nn.Identity: + self.guidance_in.disable_gradient_checkpointing() + + for block in self.double_blocks + self.single_blocks: + block.disable_gradient_checkpointing() + + print("FLUX: Gradient checkpointing disabled.") + + def enable_block_swap(self, num_blocks: int, device: torch.device): + self.blocks_to_swap = num_blocks + double_blocks_to_swap = num_blocks // 2 + single_blocks_to_swap = (num_blocks - double_blocks_to_swap) * 2 + + assert double_blocks_to_swap <= self.num_double_blocks - 2 and single_blocks_to_swap <= self.num_single_blocks - 2, ( + f"Cannot swap more than {self.num_double_blocks - 2} double blocks and {self.num_single_blocks - 2} single blocks. " + f"Requested {double_blocks_to_swap} double blocks and {single_blocks_to_swap} single blocks." + ) + + self.offloader_double = custom_offloading_utils.ModelOffloader( + self.double_blocks, double_blocks_to_swap, device # , debug=True + ) + self.offloader_single = custom_offloading_utils.ModelOffloader( + self.single_blocks, single_blocks_to_swap, device # , debug=True + ) + print( + f"FLUX: Block swap enabled. Swapping {num_blocks} blocks, double blocks: {double_blocks_to_swap}, single blocks: {single_blocks_to_swap}." + ) + + def move_to_device_except_swap_blocks(self, device: torch.device): + # assume model is on cpu. do not move blocks to device to reduce temporary memory usage + if self.blocks_to_swap: + save_double_blocks = self.double_blocks + save_single_blocks = self.single_blocks + self.double_blocks = nn.ModuleList() + self.single_blocks = nn.ModuleList() + + self.to(device) + + if self.blocks_to_swap: + self.double_blocks = save_double_blocks + self.single_blocks = save_single_blocks + + def prepare_block_swap_before_forward(self): + if self.blocks_to_swap is None or self.blocks_to_swap == 0: + return + self.offloader_double.prepare_block_devices_before_forward(self.double_blocks) + self.offloader_single.prepare_block_devices_before_forward(self.single_blocks) + + def forward( + self, + img: Tensor, + img_ids: Tensor, + controlnet_cond: Tensor, + txt: Tensor, + txt_ids: Tensor, + timesteps: Tensor, + y: Tensor, + guidance: Tensor | None = None, + txt_attention_mask: Tensor | None = None, + ) -> tuple[tuple[Tensor]]: + if img.ndim != 3 or txt.ndim != 3: + raise ValueError("Input img and txt tensors must have 3 dimensions.") + + # running on sequences img + img = self.img_in(img) + controlnet_cond = self.input_hint_block(controlnet_cond) + controlnet_cond = rearrange(controlnet_cond, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2) + controlnet_cond = self.pos_embed_input(controlnet_cond) + img = img + controlnet_cond + vec = self.time_in(timestep_embedding(timesteps, 256)) + if self.params.guidance_embed: + if guidance is None: + raise ValueError("Didn't get guidance strength for guidance distilled model.") + vec = vec + self.guidance_in(timestep_embedding(guidance, 256)) + vec = vec + self.vector_in(y) + txt = self.txt_in(txt) + + ids = torch.cat((txt_ids, img_ids), dim=1) + pe = self.pe_embedder(ids) + + block_samples = () + block_single_samples = () + if not self.blocks_to_swap: + for block in self.double_blocks: + img, txt = block(img=img, txt=txt, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask) + block_samples = block_samples + (img,) + + img = torch.cat((txt, img), 1) + for block in self.single_blocks: + img = block(img, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask) + block_single_samples = block_single_samples + (img,) + else: + for block_idx, block in enumerate(self.double_blocks): + self.offloader_double.wait_for_block(block_idx) + + img, txt = block(img=img, txt=txt, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask) + block_samples = block_samples + (img,) + + self.offloader_double.submit_move_blocks(self.double_blocks, block_idx) + + img = torch.cat((txt, img), 1) + + for block_idx, block in enumerate(self.single_blocks): + self.offloader_single.wait_for_block(block_idx) + + img = block(img, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask) + block_single_samples = block_single_samples + (img,) + + self.offloader_single.submit_move_blocks(self.single_blocks, block_idx) + + controlnet_block_samples = () + controlnet_single_block_samples = () + for block_sample, controlnet_block in zip(block_samples, self.controlnet_blocks): + block_sample = controlnet_block(block_sample) + controlnet_block_samples = controlnet_block_samples + (block_sample,) + for block_sample, controlnet_block in zip(block_samples, self.controlnet_blocks_for_single): + block_sample = controlnet_block(block_sample) + controlnet_single_block_samples = controlnet_single_block_samples + (block_sample,) + + return controlnet_block_samples, controlnet_single_block_samples diff --git a/library/flux_train_utils.py b/library/flux_train_utils.py new file mode 100644 index 000000000..06fe0b953 --- /dev/null +++ b/library/flux_train_utils.py @@ -0,0 +1,690 @@ +import argparse +import math +import os +import numpy as np +import toml +import json +import time +from typing import Callable, Dict, List, Optional, Tuple, Union + +import torch +from accelerate import Accelerator, PartialState +from transformers import CLIPTextModel +from tqdm import tqdm +from PIL import Image +from safetensors.torch import save_file + +from library import flux_models, flux_utils, strategy_base, train_util +from library.device_utils import init_ipex, clean_memory_on_device +from library.safetensors_utils import mem_eff_save_file + +init_ipex() + +from .utils import setup_logging + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + + +# region sample images + + +def sample_images( + accelerator: Accelerator, + args: argparse.Namespace, + epoch, + steps, + flux, + ae, + text_encoders, + sample_prompts_te_outputs, + prompt_replacement=None, + controlnet=None, +): + if steps == 0: + if not args.sample_at_first: + return + else: + if args.sample_every_n_steps is None and args.sample_every_n_epochs is None: + return + if args.sample_every_n_epochs is not None: + # sample_every_n_steps は無視する + if epoch is None or epoch % args.sample_every_n_epochs != 0: + return + else: + if steps % args.sample_every_n_steps != 0 or epoch is not None: # steps is not divisible or end of epoch + return + + logger.info("") + logger.info(f"generating sample images at step / サンプル画像生成 ステップ: {steps}") + if not os.path.isfile(args.sample_prompts) and sample_prompts_te_outputs is None: + logger.error(f"No prompt file / プロンプトファイルがありません: {args.sample_prompts}") + return + + distributed_state = PartialState() # for multi gpu distributed inference. this is a singleton, so it's safe to use it here + + # unwrap unet and text_encoder(s) + flux = accelerator.unwrap_model(flux) + if text_encoders is not None: + text_encoders = [(accelerator.unwrap_model(te) if te is not None else None) for te in text_encoders] + if controlnet is not None: + controlnet = accelerator.unwrap_model(controlnet) + # print([(te.parameters().__next__().device if te is not None else None) for te in text_encoders]) + + prompts = train_util.load_prompts(args.sample_prompts) + + save_dir = args.output_dir + "/sample" + os.makedirs(save_dir, exist_ok=True) + + # save random state to restore later + rng_state = torch.get_rng_state() + cuda_rng_state = None + try: + cuda_rng_state = torch.cuda.get_rng_state() if torch.cuda.is_available() else None + except Exception: + pass + + if distributed_state.num_processes <= 1: + # If only one device is available, just use the original prompt list. We don't need to care about the distribution of prompts. + with torch.no_grad(), accelerator.autocast(): + for prompt_dict in prompts: + sample_image_inference( + accelerator, + args, + flux, + text_encoders, + ae, + save_dir, + prompt_dict, + epoch, + steps, + sample_prompts_te_outputs, + prompt_replacement, + controlnet, + ) + else: + # Creating list with N elements, where each element is a list of prompt_dicts, and N is the number of processes available (number of devices available) + # prompt_dicts are assigned to lists based on order of processes, to attempt to time the image creation time to match enum order. Probably only works when steps and sampler are identical. + per_process_prompts = [] # list of lists + for i in range(distributed_state.num_processes): + per_process_prompts.append(prompts[i :: distributed_state.num_processes]) + + with torch.no_grad(): + with distributed_state.split_between_processes(per_process_prompts) as prompt_dict_lists: + for prompt_dict in prompt_dict_lists[0]: + sample_image_inference( + accelerator, + args, + flux, + text_encoders, + ae, + save_dir, + prompt_dict, + epoch, + steps, + sample_prompts_te_outputs, + prompt_replacement, + controlnet, + ) + + torch.set_rng_state(rng_state) + if cuda_rng_state is not None: + torch.cuda.set_rng_state(cuda_rng_state) + + clean_memory_on_device(accelerator.device) + + +def sample_image_inference( + accelerator: Accelerator, + args: argparse.Namespace, + flux: flux_models.Flux, + text_encoders: Optional[List[CLIPTextModel]], + ae: flux_models.AutoEncoder, + save_dir, + prompt_dict, + epoch, + steps, + sample_prompts_te_outputs, + prompt_replacement, + controlnet, +): + assert isinstance(prompt_dict, dict) + negative_prompt = prompt_dict.get("negative_prompt") + sample_steps = prompt_dict.get("sample_steps", 20) + width = prompt_dict.get("width", 512) + height = prompt_dict.get("height", 512) + emb_guidance_scale = prompt_dict.get("guidance_scale", 3.5) + cfg_scale = prompt_dict.get("scale", 1.0) + seed = prompt_dict.get("seed") + controlnet_image = prompt_dict.get("controlnet_image") + prompt: str = prompt_dict.get("prompt", "") + # sampler_name: str = prompt_dict.get("sample_sampler", args.sample_sampler) + + if prompt_replacement is not None: + prompt = prompt.replace(prompt_replacement[0], prompt_replacement[1]) + if negative_prompt is not None: + negative_prompt = negative_prompt.replace(prompt_replacement[0], prompt_replacement[1]) + + if seed is not None: + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + else: + # True random sample image generation + torch.seed() + torch.cuda.seed() + + if negative_prompt is None: + negative_prompt = "" + height = max(64, height - height % 16) # round to divisible by 16 + width = max(64, width - width % 16) # round to divisible by 16 + logger.info(f"prompt: {prompt}") + if cfg_scale != 1.0: + logger.info(f"negative_prompt: {negative_prompt}") + elif negative_prompt != "": + logger.info(f"negative prompt is ignored because scale is 1.0") + logger.info(f"height: {height}") + logger.info(f"width: {width}") + logger.info(f"sample_steps: {sample_steps}") + logger.info(f"embedded guidance scale: {emb_guidance_scale}") + if cfg_scale != 1.0: + logger.info(f"CFG scale: {cfg_scale}") + # logger.info(f"sample_sampler: {sampler_name}") + if seed is not None: + logger.info(f"seed: {seed}") + + # encode prompts + tokenize_strategy = strategy_base.TokenizeStrategy.get_strategy() + encoding_strategy = strategy_base.TextEncodingStrategy.get_strategy() + + def encode_prompt(prpt): + text_encoder_conds = [] + if sample_prompts_te_outputs and prpt in sample_prompts_te_outputs: + text_encoder_conds = sample_prompts_te_outputs[prpt] + print(f"Using cached text encoder outputs for prompt: {prpt}") + if text_encoders is not None: + print(f"Encoding prompt: {prpt}") + tokens_and_masks = tokenize_strategy.tokenize(prpt) + # strategy has apply_t5_attn_mask option + encoded_text_encoder_conds = encoding_strategy.encode_tokens(tokenize_strategy, text_encoders, tokens_and_masks) + + # if text_encoder_conds is not cached, use encoded_text_encoder_conds + if len(text_encoder_conds) == 0: + text_encoder_conds = encoded_text_encoder_conds + else: + # if encoded_text_encoder_conds is not None, update cached text_encoder_conds + for i in range(len(encoded_text_encoder_conds)): + if encoded_text_encoder_conds[i] is not None: + text_encoder_conds[i] = encoded_text_encoder_conds[i] + return text_encoder_conds + + l_pooled, t5_out, txt_ids, t5_attn_mask = encode_prompt(prompt) + # encode negative prompts + if cfg_scale != 1.0: + neg_l_pooled, neg_t5_out, _, neg_t5_attn_mask = encode_prompt(negative_prompt) + neg_t5_attn_mask = ( + neg_t5_attn_mask.to(accelerator.device) if args.apply_t5_attn_mask and neg_t5_attn_mask is not None else None + ) + neg_cond = (cfg_scale, neg_l_pooled, neg_t5_out, neg_t5_attn_mask) + else: + neg_cond = None + + # sample image + weight_dtype = ae.dtype # TOFO give dtype as argument + packed_latent_height = height // 16 + packed_latent_width = width // 16 + noise = torch.randn( + 1, + packed_latent_height * packed_latent_width, + 16 * 2 * 2, + device=accelerator.device, + dtype=weight_dtype, + generator=torch.Generator(device=accelerator.device).manual_seed(seed) if seed is not None else None, + ) + timesteps = get_schedule(sample_steps, noise.shape[1], shift=True) # Chroma can use shift=True + img_ids = flux_utils.prepare_img_ids(1, packed_latent_height, packed_latent_width).to(accelerator.device, weight_dtype) + t5_attn_mask = t5_attn_mask.to(accelerator.device) if args.apply_t5_attn_mask else None + + if controlnet_image is not None: + controlnet_image = Image.open(controlnet_image).convert("RGB") + controlnet_image = controlnet_image.resize((width, height), Image.LANCZOS) + controlnet_image = torch.from_numpy((np.array(controlnet_image) / 127.5) - 1) + controlnet_image = controlnet_image.permute(2, 0, 1).unsqueeze(0).to(weight_dtype).to(accelerator.device) + + with accelerator.autocast(), torch.no_grad(): + x = denoise( + flux, + noise, + img_ids, + t5_out, + txt_ids, + l_pooled, + timesteps=timesteps, + guidance=emb_guidance_scale, + t5_attn_mask=t5_attn_mask, + controlnet=controlnet, + controlnet_img=controlnet_image, + neg_cond=neg_cond, + ) + + x = flux_utils.unpack_latents(x, packed_latent_height, packed_latent_width) + + # latent to image + clean_memory_on_device(accelerator.device) + org_vae_device = ae.device # will be on cpu + ae.to(accelerator.device) # distributed_state.device is same as accelerator.device + with accelerator.autocast(), torch.no_grad(): + x = ae.decode(x) + ae.to(org_vae_device) + clean_memory_on_device(accelerator.device) + + x = x.clamp(-1, 1) + x = x.permute(0, 2, 3, 1) + image = Image.fromarray((127.5 * (x + 1.0)).float().cpu().numpy().astype(np.uint8)[0]) + + # adding accelerator.wait_for_everyone() here should sync up and ensure that sample images are saved in the same order as the original prompt list + # but adding 'enum' to the filename should be enough + + ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime()) + num_suffix = f"e{epoch:06d}" if epoch is not None else f"{steps:06d}" + seed_suffix = "" if seed is None else f"_{seed}" + i: int = prompt_dict["enum"] + img_filename = f"{'' if args.output_name is None else args.output_name + '_'}{num_suffix}_{i:02d}_{ts_str}{seed_suffix}.png" + image.save(os.path.join(save_dir, img_filename)) + + # send images to wandb if enabled + if "wandb" in [tracker.name for tracker in accelerator.trackers]: + wandb_tracker = accelerator.get_tracker("wandb") + + import wandb + + # not to commit images to avoid inconsistency between training and logging steps + wandb_tracker.log({f"sample_{i}": wandb.Image(image, caption=prompt)}, commit=False) # positive prompt as a caption + + +def time_shift(mu: float, sigma: float, t: torch.Tensor): + return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma) + + +def get_lin_function(x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15) -> Callable[[float], float]: + m = (y2 - y1) / (x2 - x1) + b = y1 - m * x1 + return lambda x: m * x + b + + +def get_schedule( + num_steps: int, + image_seq_len: int, + base_shift: float = 0.5, + max_shift: float = 1.15, + shift: bool = True, +) -> list[float]: + # extra step for zero + timesteps = torch.linspace(1, 0, num_steps + 1) + + # shifting the schedule to favor high timesteps for higher signal images + if shift: + # eastimate mu based on linear estimation between two points + mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len) + timesteps = time_shift(mu, 1.0, timesteps) + + return timesteps.tolist() + + +def denoise( + model: flux_models.Flux, + img: torch.Tensor, + img_ids: torch.Tensor, + txt: torch.Tensor, # t5_out + txt_ids: torch.Tensor, + vec: torch.Tensor, # l_pooled + timesteps: list[float], + guidance: float = 4.0, + t5_attn_mask: Optional[torch.Tensor] = None, + controlnet: Optional[flux_models.ControlNetFlux] = None, + controlnet_img: Optional[torch.Tensor] = None, + neg_cond: Optional[Tuple[float, torch.Tensor, torch.Tensor, torch.Tensor]] = None, +): + # this is ignored for schnell + guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype) + do_cfg = neg_cond is not None + + for t_curr, t_prev in zip(tqdm(timesteps[:-1]), timesteps[1:]): + t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device) + model.prepare_block_swap_before_forward() + + if controlnet is not None: + block_samples, block_single_samples = controlnet( + img=img, + img_ids=img_ids, + controlnet_cond=controlnet_img, + txt=txt, + txt_ids=txt_ids, + y=vec, + timesteps=t_vec, + guidance=guidance_vec, + txt_attention_mask=t5_attn_mask, + ) + else: + block_samples = None + block_single_samples = None + + if not do_cfg: + pred = model( + img=img, + img_ids=img_ids, + txt=txt, + txt_ids=txt_ids, + y=vec, + block_controlnet_hidden_states=block_samples, + block_controlnet_single_hidden_states=block_single_samples, + timesteps=t_vec, + guidance=guidance_vec, + txt_attention_mask=t5_attn_mask, + ) + + img = img + (t_prev - t_curr) * pred + else: + cfg_scale, neg_l_pooled, neg_t5_out, neg_t5_attn_mask = neg_cond + nc_c_t5_attn_mask = None if t5_attn_mask is None else torch.cat([neg_t5_attn_mask, t5_attn_mask], dim=0) + + # TODO is it ok to use the same block samples for both cond and uncond? + block_samples = None if block_samples is None else torch.cat([block_samples, block_samples], dim=0) + block_single_samples = ( + None if block_single_samples is None else torch.cat([block_single_samples, block_single_samples], dim=0) + ) + + nc_c_pred = model( + img=torch.cat([img, img], dim=0), + img_ids=torch.cat([img_ids, img_ids], dim=0), + txt=torch.cat([neg_t5_out, txt], dim=0), + txt_ids=torch.cat([txt_ids, txt_ids], dim=0), + y=torch.cat([neg_l_pooled, vec], dim=0), + block_controlnet_hidden_states=block_samples, + block_controlnet_single_hidden_states=block_single_samples, + timesteps=t_vec.repeat(2), + guidance=guidance_vec.repeat(2), + txt_attention_mask=nc_c_t5_attn_mask, + ) + neg_pred, pred = torch.chunk(nc_c_pred, 2, dim=0) + pred = neg_pred + (pred - neg_pred) * cfg_scale + + img = img + (t_prev - t_curr) * pred + + model.prepare_block_swap_before_forward() + return img + + +# endregion + + +# region train +def get_sigmas(noise_scheduler, timesteps, device, n_dim=4, dtype=torch.float32): + sigmas = noise_scheduler.sigmas.to(device=device, dtype=dtype) + schedule_timesteps = noise_scheduler.timesteps.to(device) + timesteps = timesteps.to(device) + step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] + + sigma = sigmas[step_indices].flatten() + return sigma + + +def compute_density_for_timestep_sampling( + weighting_scheme: str, batch_size: int, logit_mean: float = None, logit_std: float = None, mode_scale: float = None +): + """Compute the density for sampling the timesteps when doing SD3 training. + + Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528. + + SD3 paper reference: https://arxiv.org/abs/2403.03206v1. + """ + if weighting_scheme == "logit_normal": + # See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$). + u = torch.normal(mean=logit_mean, std=logit_std, size=(batch_size,), device="cpu") + u = torch.nn.functional.sigmoid(u) + elif weighting_scheme == "mode": + u = torch.rand(size=(batch_size,), device="cpu") + u = 1 - u - mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u) + else: + u = torch.rand(size=(batch_size,), device="cpu") + return u + + +def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None): + """Computes loss weighting scheme for SD3 training. + + Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528. + + SD3 paper reference: https://arxiv.org/abs/2403.03206v1. + """ + if weighting_scheme == "sigma_sqrt": + weighting = (sigmas**-2.0).float() + elif weighting_scheme == "cosmap": + bot = 1 - 2 * sigmas + 2 * sigmas**2 + weighting = 2 / (math.pi * bot) + else: + weighting = torch.ones_like(sigmas) + return weighting + + +def get_noisy_model_input_and_timesteps( + args, noise_scheduler, latents: torch.Tensor, noise: torch.Tensor, device, dtype +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + bsz, _, h, w = latents.shape + assert bsz > 0, "Batch size not large enough" + num_timesteps = noise_scheduler.config.num_train_timesteps + if args.timestep_sampling == "uniform" or args.timestep_sampling == "sigmoid": + # Simple random sigma-based noise sampling + if args.timestep_sampling == "sigmoid": + # https://github.com/XLabs-AI/x-flux/tree/main + sigmas = torch.sigmoid(args.sigmoid_scale * torch.randn((bsz,), device=device)) + else: + sigmas = torch.rand((bsz,), device=device) + + timesteps = sigmas * num_timesteps + elif args.timestep_sampling == "shift": + shift = args.discrete_flow_shift + sigmas = torch.randn(bsz, device=device) + sigmas = sigmas * args.sigmoid_scale # larger scale for more uniform sampling + sigmas = sigmas.sigmoid() + sigmas = (sigmas * shift) / (1 + (shift - 1) * sigmas) + timesteps = sigmas * num_timesteps + elif args.timestep_sampling == "flux_shift": + sigmas = torch.randn(bsz, device=device) + sigmas = sigmas * args.sigmoid_scale # larger scale for more uniform sampling + sigmas = sigmas.sigmoid() + mu = get_lin_function(y1=0.5, y2=1.15)((h // 2) * (w // 2)) # we are pre-packed so must adjust for packed size + sigmas = time_shift(mu, 1.0, sigmas) + timesteps = sigmas * num_timesteps + else: + # Sample a random timestep for each image + # for weighting schemes where we sample timesteps non-uniformly + u = compute_density_for_timestep_sampling( + weighting_scheme=args.weighting_scheme, + batch_size=bsz, + logit_mean=args.logit_mean, + logit_std=args.logit_std, + mode_scale=args.mode_scale, + ) + indices = (u * num_timesteps).long() + timesteps = noise_scheduler.timesteps[indices].to(device=device) + sigmas = get_sigmas(noise_scheduler, timesteps, device, n_dim=latents.ndim, dtype=dtype) + + # Broadcast sigmas to latent shape + sigmas = sigmas.view(-1, 1, 1, 1) + + # Add noise to the latents according to the noise magnitude at each timestep + # (this is the forward diffusion process) + if args.ip_noise_gamma: + xi = torch.randn_like(latents, device=latents.device, dtype=dtype) + if args.ip_noise_gamma_random_strength: + ip_noise_gamma = torch.rand(1, device=latents.device, dtype=dtype) * args.ip_noise_gamma + else: + ip_noise_gamma = args.ip_noise_gamma + noisy_model_input = (1.0 - sigmas) * latents + sigmas * (noise + ip_noise_gamma * xi) + else: + noisy_model_input = (1.0 - sigmas) * latents + sigmas * noise + + return noisy_model_input.to(dtype), timesteps.to(dtype), sigmas + + +def apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas): + weighting = None + if args.model_prediction_type == "raw": + pass + elif args.model_prediction_type == "additive": + # add the model_pred to the noisy_model_input + model_pred = model_pred + noisy_model_input + elif args.model_prediction_type == "sigma_scaled": + # apply sigma scaling + model_pred = model_pred * (-sigmas) + noisy_model_input + + # these weighting schemes use a uniform timestep sampling + # and instead post-weight the loss + weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas) + + return model_pred, weighting + + +def save_models( + ckpt_path: str, + flux: flux_models.Flux, + sai_metadata: Optional[dict], + save_dtype: Optional[torch.dtype] = None, + use_mem_eff_save: bool = False, +): + state_dict = {} + + def update_sd(prefix, sd): + for k, v in sd.items(): + key = prefix + k + if save_dtype is not None and v.dtype != save_dtype: + v = v.detach().clone().to("cpu").to(save_dtype) + state_dict[key] = v + + update_sd("", flux.state_dict()) + + if not use_mem_eff_save: + save_file(state_dict, ckpt_path, metadata=sai_metadata) + else: + mem_eff_save_file(state_dict, ckpt_path, metadata=sai_metadata) + + +def save_flux_model_on_train_end( + args: argparse.Namespace, save_dtype: torch.dtype, epoch: int, global_step: int, flux: flux_models.Flux +): + def sd_saver(ckpt_file, epoch_no, global_step): + sai_metadata = train_util.get_sai_model_spec(None, args, False, False, False, is_stable_diffusion_ckpt=True, flux="dev") + save_models(ckpt_file, flux, sai_metadata, save_dtype, args.mem_eff_save) + + train_util.save_sd_model_on_train_end_common(args, True, True, epoch, global_step, sd_saver, None) + + +# epochとstepの保存、メタデータにepoch/stepが含まれ引数が同じになるため、統合している +# on_epoch_end: Trueならepoch終了時、Falseならstep経過時 +def save_flux_model_on_epoch_end_or_stepwise( + args: argparse.Namespace, + on_epoch_end: bool, + accelerator, + save_dtype: torch.dtype, + epoch: int, + num_train_epochs: int, + global_step: int, + flux: flux_models.Flux, +): + def sd_saver(ckpt_file, epoch_no, global_step): + sai_metadata = train_util.get_sai_model_spec(None, args, False, False, False, is_stable_diffusion_ckpt=True, flux="dev") + save_models(ckpt_file, flux, sai_metadata, save_dtype, args.mem_eff_save) + + train_util.save_sd_model_on_epoch_end_or_stepwise_common( + args, + on_epoch_end, + accelerator, + True, + True, + epoch, + num_train_epochs, + global_step, + sd_saver, + None, + ) + + +# endregion + + +def add_flux_train_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--clip_l", + type=str, + help="path to clip_l (*.sft or *.safetensors), should be float16 / clip_lのパス(*.sftまたは*.safetensors)、float16が前提", + ) + parser.add_argument( + "--t5xxl", + type=str, + help="path to t5xxl (*.sft or *.safetensors), should be float16 / t5xxlのパス(*.sftまたは*.safetensors)、float16が前提", + ) + parser.add_argument("--ae", type=str, help="path to ae (*.sft or *.safetensors) / aeのパス(*.sftまたは*.safetensors)") + parser.add_argument( + "--controlnet_model_name_or_path", + type=str, + default=None, + help="path to controlnet (*.sft or *.safetensors) / controlnetのパス(*.sftまたは*.safetensors)", + ) + parser.add_argument( + "--t5xxl_max_token_length", + type=int, + default=None, + help="maximum token length for T5-XXL. if omitted, 256 for schnell and 512 for dev" + " / T5-XXLの最大トークン長。省略された場合、schnellの場合は256、devの場合は512", + ) + parser.add_argument( + "--apply_t5_attn_mask", + action="store_true", + help="apply attention mask to T5-XXL encode and FLUX double blocks / T5-XXLエンコードとFLUXダブルブロックにアテンションマスクを適用する", + ) + + parser.add_argument( + "--guidance_scale", + type=float, + default=3.5, + help="the FLUX.1 dev variant is a guidance distilled model", + ) + + parser.add_argument( + "--timestep_sampling", + choices=["sigma", "uniform", "sigmoid", "shift", "flux_shift"], + default="sigma", + help="Method to sample timesteps: sigma-based, uniform random, sigmoid of random normal, shift of sigmoid and FLUX.1 shifting." + " / タイムステップをサンプリングする方法:sigma、random uniform、random normalのsigmoid、sigmoidのシフト、FLUX.1のシフト。", + ) + parser.add_argument( + "--sigmoid_scale", + type=float, + default=1.0, + help='Scale factor for sigmoid timestep sampling (only used when timestep-sampling is "sigmoid"). / sigmoidタイムステップサンプリングの倍率(timestep-samplingが"sigmoid"の場合のみ有効)。', + ) + parser.add_argument( + "--model_prediction_type", + choices=["raw", "additive", "sigma_scaled"], + default="sigma_scaled", + help="How to interpret and process the model prediction: " + "raw (use as is), additive (add to noisy input), sigma_scaled (apply sigma scaling)." + " / モデル予測の解釈と処理方法:" + "raw(そのまま使用)、additive(ノイズ入力に加算)、sigma_scaled(シグマスケーリングを適用)。", + ) + parser.add_argument( + "--discrete_flow_shift", + type=float, + default=3.0, + help="Discrete flow shift for the Euler Discrete Scheduler, default is 3.0. / Euler Discrete Schedulerの離散フローシフト、デフォルトは3.0。", + ) + + parser.add_argument( + "--model_type", + type=str, + choices=["flux", "chroma"], + default="flux", + help="Model type to use for training / トレーニングに使用するモデルタイプ:flux or chroma (default: flux)", + ) diff --git a/library/flux_utils.py b/library/flux_utils.py new file mode 100644 index 000000000..410b34ce2 --- /dev/null +++ b/library/flux_utils.py @@ -0,0 +1,563 @@ +import json +import os +from dataclasses import replace +from typing import List, Optional, Tuple, Union + +import einops +import torch +from accelerate import init_empty_weights +from safetensors import safe_open +from safetensors.torch import load_file +from transformers import CLIPConfig, CLIPTextModel, T5Config, T5EncoderModel + +from library.utils import setup_logging + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + +from library import flux_models +from library.safetensors_utils import load_safetensors + +MODEL_VERSION_FLUX_V1 = "flux1" +MODEL_NAME_DEV = "dev" +MODEL_NAME_SCHNELL = "schnell" +MODEL_VERSION_CHROMA = "chroma" + + +def analyze_checkpoint_state(ckpt_path: str) -> Tuple[bool, bool, Tuple[int, int], List[str]]: + """ + チェックポイントの状態を分析し、DiffusersかBFLか、devかschnellか、ブロック数を計算して返す。 + + Args: + ckpt_path (str): チェックポイントファイルまたはディレクトリのパス。 + + Returns: + Tuple[bool, bool, Tuple[int, int], List[str]]: + - bool: Diffusersかどうかを示すフラグ。 + - bool: Schnellかどうかを示すフラグ。 + - Tuple[int, int]: ダブルブロックとシングルブロックの数。 + - List[str]: チェックポイントに含まれるキーのリスト。 + """ + # check the state dict: Diffusers or BFL, dev or schnell, number of blocks + logger.info(f"Checking the state dict: Diffusers or BFL, dev or schnell") + + if os.path.isdir(ckpt_path): # if ckpt_path is a directory, it is Diffusers + ckpt_path = os.path.join(ckpt_path, "transformer", "diffusion_pytorch_model-00001-of-00003.safetensors") + if "00001-of-00003" in ckpt_path: + ckpt_paths = [ckpt_path.replace("00001-of-00003", f"0000{i}-of-00003") for i in range(1, 4)] + else: + ckpt_paths = [ckpt_path] + + keys = [] + for ckpt_path in ckpt_paths: + with safe_open(ckpt_path, framework="pt") as f: + keys.extend(f.keys()) + + # if the key has annoying prefix, remove it + if keys[0].startswith("model.diffusion_model."): + keys = [key.replace("model.diffusion_model.", "") for key in keys] + + is_diffusers = "transformer_blocks.0.attn.add_k_proj.bias" in keys + is_schnell = not ("guidance_in.in_layer.bias" in keys or "time_text_embed.guidance_embedder.linear_1.bias" in keys) + + # check number of double and single blocks + if not is_diffusers: + max_double_block_index = max( + [int(key.split(".")[1]) for key in keys if key.startswith("double_blocks.") and key.endswith(".img_attn.proj.bias")] + ) + max_single_block_index = max( + [int(key.split(".")[1]) for key in keys if key.startswith("single_blocks.") and key.endswith(".modulation.lin.bias")] + ) + else: + max_double_block_index = max( + [ + int(key.split(".")[1]) + for key in keys + if key.startswith("transformer_blocks.") and key.endswith(".attn.add_k_proj.bias") + ] + ) + max_single_block_index = max( + [ + int(key.split(".")[1]) + for key in keys + if key.startswith("single_transformer_blocks.") and key.endswith(".attn.to_k.bias") + ] + ) + + num_double_blocks = max_double_block_index + 1 + num_single_blocks = max_single_block_index + 1 + + return is_diffusers, is_schnell, (num_double_blocks, num_single_blocks), ckpt_paths + + +def load_flow_model( + ckpt_path: str, + dtype: Optional[torch.dtype], + device: Union[str, torch.device], + disable_mmap: bool = False, + model_type: str = "flux", +) -> Tuple[bool, flux_models.Flux]: + if model_type == "flux": + is_diffusers, is_schnell, (num_double_blocks, num_single_blocks), ckpt_paths = analyze_checkpoint_state(ckpt_path) + name = MODEL_NAME_DEV if not is_schnell else MODEL_NAME_SCHNELL + + # build model + logger.info(f"Building Flux model {name} from {'Diffusers' if is_diffusers else 'BFL'} checkpoint") + with torch.device("meta"): + params = flux_models.configs[name].params + + # set the number of blocks + if params.depth != num_double_blocks: + logger.info(f"Setting the number of double blocks from {params.depth} to {num_double_blocks}") + params = replace(params, depth=num_double_blocks) + if params.depth_single_blocks != num_single_blocks: + logger.info(f"Setting the number of single blocks from {params.depth_single_blocks} to {num_single_blocks}") + params = replace(params, depth_single_blocks=num_single_blocks) + + model = flux_models.Flux(params) + if dtype is not None: + model = model.to(dtype) + + # load_sft doesn't support torch.device + logger.info(f"Loading state dict from {ckpt_path}") + sd = {} + for ckpt_path in ckpt_paths: + sd.update(load_safetensors(ckpt_path, device=device, disable_mmap=disable_mmap, dtype=dtype)) + + # convert Diffusers to BFL + if is_diffusers: + logger.info("Converting Diffusers to BFL") + sd = convert_diffusers_sd_to_bfl(sd, num_double_blocks, num_single_blocks) + logger.info("Converted Diffusers to BFL") + + # if the key has annoying prefix, remove it + for key in list(sd.keys()): + new_key = key.replace("model.diffusion_model.", "") + if new_key == key: + break # the model doesn't have annoying prefix + sd[new_key] = sd.pop(key) + + info = model.load_state_dict(sd, strict=False, assign=True) + logger.info(f"Loaded Flux: {info}") + return is_schnell, model + + elif model_type == "chroma": + from . import chroma_models + + # build model + logger.info("Building Chroma model") + with torch.device("meta"): + model = chroma_models.Chroma(chroma_models.chroma_params) + if dtype is not None: + model = model.to(dtype) + + # load_sft doesn't support torch.device + logger.info(f"Loading state dict from {ckpt_path}") + sd = load_safetensors(ckpt_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype) + + # if the key has annoying prefix, remove it + for key in list(sd.keys()): + new_key = key.replace("model.diffusion_model.", "") + if new_key == key: + break # the model doesn't have annoying prefix + sd[new_key] = sd.pop(key) + + info = model.load_state_dict(sd, strict=False, assign=True) + logger.info(f"Loaded Chroma: {info}") + is_schnell = False # Chroma is not schnell + return is_schnell, model + + else: + raise ValueError(f"Unsupported model_type: {model_type}. Supported types are 'flux' and 'chroma'.") + + +def load_ae( + ckpt_path: str, dtype: torch.dtype, device: Union[str, torch.device], disable_mmap: bool = False +) -> flux_models.AutoEncoder: + logger.info("Building AutoEncoder") + with torch.device("meta"): + # dev and schnell have the same AE params + ae = flux_models.AutoEncoder(flux_models.configs[MODEL_NAME_DEV].ae_params).to(dtype) + + logger.info(f"Loading state dict from {ckpt_path}") + sd = load_safetensors(ckpt_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype) + info = ae.load_state_dict(sd, strict=False, assign=True) + logger.info(f"Loaded AE: {info}") + return ae + + +def load_controlnet( + ckpt_path: Optional[str], is_schnell: bool, dtype: torch.dtype, device: Union[str, torch.device], disable_mmap: bool = False +): + logger.info("Building ControlNet") + name = MODEL_NAME_DEV if not is_schnell else MODEL_NAME_SCHNELL + with torch.device(device): + controlnet = flux_models.ControlNetFlux(flux_models.configs[name].params).to(dtype) + + if ckpt_path is not None: + logger.info(f"Loading state dict from {ckpt_path}") + sd = load_safetensors(ckpt_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype) + info = controlnet.load_state_dict(sd, strict=False, assign=True) + logger.info(f"Loaded ControlNet: {info}") + return controlnet + + +def dummy_clip_l() -> torch.nn.Module: + """ + Returns a dummy CLIP-L model with the output shape of (N, 77, 768). + """ + return DummyCLIPL() + + +class DummyTextModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.embeddings = torch.nn.Parameter(torch.zeros(1)) + + +class DummyCLIPL(torch.nn.Module): + def __init__(self): + super().__init__() + self.output_shape = (77, 1) # Note: The original code had (77, 768), but we use (77, 1) for the dummy output + + # dtype and device from these parameters. train_network.py accesses them + self.dummy_param = torch.nn.Parameter(torch.zeros(1)) + self.dummy_param_2 = torch.nn.Parameter(torch.zeros(1)) + self.dummy_param_3 = torch.nn.Parameter(torch.zeros(1)) + self.text_model = DummyTextModel() + + @property + def device(self): + return self.dummy_param.device + + @property + def dtype(self): + return self.dummy_param.dtype + + def forward(self, *args, **kwargs): + """ + Returns a dummy output with the shape of (N, 77, 768). + """ + batch_size = args[0].shape[0] if args else 1 + return {"pooler_output": torch.zeros(batch_size, *self.output_shape, device=self.device, dtype=self.dtype)} + + +def load_clip_l( + ckpt_path: Optional[str], + dtype: torch.dtype, + device: Union[str, torch.device], + disable_mmap: bool = False, + state_dict: Optional[dict] = None, +) -> CLIPTextModel: + logger.info("Building CLIP-L") + CLIPL_CONFIG = { + "_name_or_path": "clip-vit-large-patch14/", + "architectures": ["CLIPModel"], + "initializer_factor": 1.0, + "logit_scale_init_value": 2.6592, + "model_type": "clip", + "projection_dim": 768, + # "text_config": { + "_name_or_path": "", + "add_cross_attention": False, + "architectures": None, + "attention_dropout": 0.0, + "bad_words_ids": None, + "bos_token_id": 0, + "chunk_size_feed_forward": 0, + "cross_attention_hidden_size": None, + "decoder_start_token_id": None, + "diversity_penalty": 0.0, + "do_sample": False, + "dropout": 0.0, + "early_stopping": False, + "encoder_no_repeat_ngram_size": 0, + "eos_token_id": 2, + "finetuning_task": None, + "forced_bos_token_id": None, + "forced_eos_token_id": None, + "hidden_act": "quick_gelu", + "hidden_size": 768, + "id2label": {"0": "LABEL_0", "1": "LABEL_1"}, + "initializer_factor": 1.0, + "initializer_range": 0.02, + "intermediate_size": 3072, + "is_decoder": False, + "is_encoder_decoder": False, + "label2id": {"LABEL_0": 0, "LABEL_1": 1}, + "layer_norm_eps": 1e-05, + "length_penalty": 1.0, + "max_length": 20, + "max_position_embeddings": 77, + "min_length": 0, + "model_type": "clip_text_model", + "no_repeat_ngram_size": 0, + "num_attention_heads": 12, + "num_beam_groups": 1, + "num_beams": 1, + "num_hidden_layers": 12, + "num_return_sequences": 1, + "output_attentions": False, + "output_hidden_states": False, + "output_scores": False, + "pad_token_id": 1, + "prefix": None, + "problem_type": None, + "projection_dim": 768, + "pruned_heads": {}, + "remove_invalid_values": False, + "repetition_penalty": 1.0, + "return_dict": True, + "return_dict_in_generate": False, + "sep_token_id": None, + "task_specific_params": None, + "temperature": 1.0, + "tie_encoder_decoder": False, + "tie_word_embeddings": True, + "tokenizer_class": None, + "top_k": 50, + "top_p": 1.0, + "torch_dtype": None, + "torchscript": False, + "transformers_version": "4.16.0.dev0", + "use_bfloat16": False, + "vocab_size": 49408, + "hidden_act": "gelu", + "hidden_size": 1280, + "intermediate_size": 5120, + "num_attention_heads": 20, + "num_hidden_layers": 32, + # }, + # "text_config_dict": { + "hidden_size": 768, + "intermediate_size": 3072, + "num_attention_heads": 12, + "num_hidden_layers": 12, + "projection_dim": 768, + # }, + # "torch_dtype": "float32", + # "transformers_version": None, + } + config = CLIPConfig(**CLIPL_CONFIG) + with init_empty_weights(): + clip = CLIPTextModel._from_config(config) + + if state_dict is not None: + sd = state_dict + else: + logger.info(f"Loading state dict from {ckpt_path}") + sd = load_safetensors(ckpt_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype) + info = clip.load_state_dict(sd, strict=False, assign=True) + logger.info(f"Loaded CLIP-L: {info}") + return clip + + +def load_t5xxl( + ckpt_path: str, + dtype: Optional[torch.dtype], + device: Union[str, torch.device], + disable_mmap: bool = False, + state_dict: Optional[dict] = None, +) -> T5EncoderModel: + T5_CONFIG_JSON = """ +{ + "architectures": [ + "T5EncoderModel" + ], + "classifier_dropout": 0.0, + "d_ff": 10240, + "d_kv": 64, + "d_model": 4096, + "decoder_start_token_id": 0, + "dense_act_fn": "gelu_new", + "dropout_rate": 0.1, + "eos_token_id": 1, + "feed_forward_proj": "gated-gelu", + "initializer_factor": 1.0, + "is_encoder_decoder": true, + "is_gated_act": true, + "layer_norm_epsilon": 1e-06, + "model_type": "t5", + "num_decoder_layers": 24, + "num_heads": 64, + "num_layers": 24, + "output_past": true, + "pad_token_id": 0, + "relative_attention_max_distance": 128, + "relative_attention_num_buckets": 32, + "tie_word_embeddings": false, + "torch_dtype": "float16", + "transformers_version": "4.41.2", + "use_cache": true, + "vocab_size": 32128 +} +""" + config = json.loads(T5_CONFIG_JSON) + config = T5Config(**config) + with init_empty_weights(): + t5xxl = T5EncoderModel._from_config(config) + + if state_dict is not None: + sd = state_dict + else: + logger.info(f"Loading state dict from {ckpt_path}") + sd = load_safetensors(ckpt_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype) + info = t5xxl.load_state_dict(sd, strict=False, assign=True) + logger.info(f"Loaded T5xxl: {info}") + return t5xxl + + +def get_t5xxl_actual_dtype(t5xxl: T5EncoderModel) -> torch.dtype: + # nn.Embedding is the first layer, but it could be casted to bfloat16 or float32 + return t5xxl.encoder.block[0].layer[0].SelfAttention.q.weight.dtype + + +def prepare_img_ids(batch_size: int, packed_latent_height: int, packed_latent_width: int): + img_ids = torch.zeros(packed_latent_height, packed_latent_width, 3) + img_ids[..., 1] = img_ids[..., 1] + torch.arange(packed_latent_height)[:, None] + img_ids[..., 2] = img_ids[..., 2] + torch.arange(packed_latent_width)[None, :] + img_ids = einops.repeat(img_ids, "h w c -> b (h w) c", b=batch_size) + return img_ids + + +def unpack_latents(x: torch.Tensor, packed_latent_height: int, packed_latent_width: int) -> torch.Tensor: + """ + x: [b (h w) (c ph pw)] -> [b c (h ph) (w pw)], ph=2, pw=2 + """ + x = einops.rearrange(x, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=packed_latent_height, w=packed_latent_width, ph=2, pw=2) + return x + + +def pack_latents(x: torch.Tensor) -> torch.Tensor: + """ + x: [b c (h ph) (w pw)] -> [b (h w) (c ph pw)], ph=2, pw=2 + """ + x = einops.rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2) + return x + + +# region Diffusers + +NUM_DOUBLE_BLOCKS = 19 +NUM_SINGLE_BLOCKS = 38 + +BFL_TO_DIFFUSERS_MAP = { + "time_in.in_layer.weight": ["time_text_embed.timestep_embedder.linear_1.weight"], + "time_in.in_layer.bias": ["time_text_embed.timestep_embedder.linear_1.bias"], + "time_in.out_layer.weight": ["time_text_embed.timestep_embedder.linear_2.weight"], + "time_in.out_layer.bias": ["time_text_embed.timestep_embedder.linear_2.bias"], + "vector_in.in_layer.weight": ["time_text_embed.text_embedder.linear_1.weight"], + "vector_in.in_layer.bias": ["time_text_embed.text_embedder.linear_1.bias"], + "vector_in.out_layer.weight": ["time_text_embed.text_embedder.linear_2.weight"], + "vector_in.out_layer.bias": ["time_text_embed.text_embedder.linear_2.bias"], + "guidance_in.in_layer.weight": ["time_text_embed.guidance_embedder.linear_1.weight"], + "guidance_in.in_layer.bias": ["time_text_embed.guidance_embedder.linear_1.bias"], + "guidance_in.out_layer.weight": ["time_text_embed.guidance_embedder.linear_2.weight"], + "guidance_in.out_layer.bias": ["time_text_embed.guidance_embedder.linear_2.bias"], + "txt_in.weight": ["context_embedder.weight"], + "txt_in.bias": ["context_embedder.bias"], + "img_in.weight": ["x_embedder.weight"], + "img_in.bias": ["x_embedder.bias"], + "double_blocks.().img_mod.lin.weight": ["norm1.linear.weight"], + "double_blocks.().img_mod.lin.bias": ["norm1.linear.bias"], + "double_blocks.().txt_mod.lin.weight": ["norm1_context.linear.weight"], + "double_blocks.().txt_mod.lin.bias": ["norm1_context.linear.bias"], + "double_blocks.().img_attn.qkv.weight": ["attn.to_q.weight", "attn.to_k.weight", "attn.to_v.weight"], + "double_blocks.().img_attn.qkv.bias": ["attn.to_q.bias", "attn.to_k.bias", "attn.to_v.bias"], + "double_blocks.().txt_attn.qkv.weight": ["attn.add_q_proj.weight", "attn.add_k_proj.weight", "attn.add_v_proj.weight"], + "double_blocks.().txt_attn.qkv.bias": ["attn.add_q_proj.bias", "attn.add_k_proj.bias", "attn.add_v_proj.bias"], + "double_blocks.().img_attn.norm.query_norm.scale": ["attn.norm_q.weight"], + "double_blocks.().img_attn.norm.key_norm.scale": ["attn.norm_k.weight"], + "double_blocks.().txt_attn.norm.query_norm.scale": ["attn.norm_added_q.weight"], + "double_blocks.().txt_attn.norm.key_norm.scale": ["attn.norm_added_k.weight"], + "double_blocks.().img_mlp.0.weight": ["ff.net.0.proj.weight"], + "double_blocks.().img_mlp.0.bias": ["ff.net.0.proj.bias"], + "double_blocks.().img_mlp.2.weight": ["ff.net.2.weight"], + "double_blocks.().img_mlp.2.bias": ["ff.net.2.bias"], + "double_blocks.().txt_mlp.0.weight": ["ff_context.net.0.proj.weight"], + "double_blocks.().txt_mlp.0.bias": ["ff_context.net.0.proj.bias"], + "double_blocks.().txt_mlp.2.weight": ["ff_context.net.2.weight"], + "double_blocks.().txt_mlp.2.bias": ["ff_context.net.2.bias"], + "double_blocks.().img_attn.proj.weight": ["attn.to_out.0.weight"], + "double_blocks.().img_attn.proj.bias": ["attn.to_out.0.bias"], + "double_blocks.().txt_attn.proj.weight": ["attn.to_add_out.weight"], + "double_blocks.().txt_attn.proj.bias": ["attn.to_add_out.bias"], + "single_blocks.().modulation.lin.weight": ["norm.linear.weight"], + "single_blocks.().modulation.lin.bias": ["norm.linear.bias"], + "single_blocks.().linear1.weight": ["attn.to_q.weight", "attn.to_k.weight", "attn.to_v.weight", "proj_mlp.weight"], + "single_blocks.().linear1.bias": ["attn.to_q.bias", "attn.to_k.bias", "attn.to_v.bias", "proj_mlp.bias"], + "single_blocks.().linear2.weight": ["proj_out.weight"], + "single_blocks.().norm.query_norm.scale": ["attn.norm_q.weight"], + "single_blocks.().norm.key_norm.scale": ["attn.norm_k.weight"], + "single_blocks.().linear2.weight": ["proj_out.weight"], + "single_blocks.().linear2.bias": ["proj_out.bias"], + "final_layer.linear.weight": ["proj_out.weight"], + "final_layer.linear.bias": ["proj_out.bias"], + "final_layer.adaLN_modulation.1.weight": ["norm_out.linear.weight"], + "final_layer.adaLN_modulation.1.bias": ["norm_out.linear.bias"], +} + + +def make_diffusers_to_bfl_map(num_double_blocks: int, num_single_blocks: int) -> dict[str, tuple[int, str]]: + # make reverse map from diffusers map + diffusers_to_bfl_map = {} # key: diffusers_key, value: (index, bfl_key) + for b in range(num_double_blocks): + for key, weights in BFL_TO_DIFFUSERS_MAP.items(): + if key.startswith("double_blocks."): + block_prefix = f"transformer_blocks.{b}." + for i, weight in enumerate(weights): + diffusers_to_bfl_map[f"{block_prefix}{weight}"] = (i, key.replace("()", f"{b}")) + for b in range(num_single_blocks): + for key, weights in BFL_TO_DIFFUSERS_MAP.items(): + if key.startswith("single_blocks."): + block_prefix = f"single_transformer_blocks.{b}." + for i, weight in enumerate(weights): + diffusers_to_bfl_map[f"{block_prefix}{weight}"] = (i, key.replace("()", f"{b}")) + for key, weights in BFL_TO_DIFFUSERS_MAP.items(): + if not (key.startswith("double_blocks.") or key.startswith("single_blocks.")): + for i, weight in enumerate(weights): + diffusers_to_bfl_map[weight] = (i, key) + return diffusers_to_bfl_map + + +def convert_diffusers_sd_to_bfl( + diffusers_sd: dict[str, torch.Tensor], num_double_blocks: int = NUM_DOUBLE_BLOCKS, num_single_blocks: int = NUM_SINGLE_BLOCKS +) -> dict[str, torch.Tensor]: + diffusers_to_bfl_map = make_diffusers_to_bfl_map(num_double_blocks, num_single_blocks) + + # iterate over three safetensors files to reduce memory usage + flux_sd = {} + for diffusers_key, tensor in diffusers_sd.items(): + if diffusers_key in diffusers_to_bfl_map: + index, bfl_key = diffusers_to_bfl_map[diffusers_key] + if bfl_key not in flux_sd: + flux_sd[bfl_key] = [] + flux_sd[bfl_key].append((index, tensor)) + else: + logger.error(f"Error: Key not found in diffusers_to_bfl_map: {diffusers_key}") + raise KeyError(f"Key not found in diffusers_to_bfl_map: {diffusers_key}") + + # concat tensors if multiple tensors are mapped to a single key, sort by index + for key, values in flux_sd.items(): + if len(values) == 1: + flux_sd[key] = values[0][1] + else: + flux_sd[key] = torch.cat([value[1] for value in sorted(values, key=lambda x: x[0])]) + + # special case for final_layer.adaLN_modulation.1.weight and final_layer.adaLN_modulation.1.bias + def swap_scale_shift(weight): + shift, scale = weight.chunk(2, dim=0) + new_weight = torch.cat([scale, shift], dim=0) + return new_weight + + if "final_layer.adaLN_modulation.1.weight" in flux_sd: + flux_sd["final_layer.adaLN_modulation.1.weight"] = swap_scale_shift(flux_sd["final_layer.adaLN_modulation.1.weight"]) + if "final_layer.adaLN_modulation.1.bias" in flux_sd: + flux_sd["final_layer.adaLN_modulation.1.bias"] = swap_scale_shift(flux_sd["final_layer.adaLN_modulation.1.bias"]) + + return flux_sd + + +# endregion diff --git a/library/fp8_optimization_utils.py b/library/fp8_optimization_utils.py new file mode 100644 index 000000000..02f99ab6d --- /dev/null +++ b/library/fp8_optimization_utils.py @@ -0,0 +1,469 @@ +import os +from typing import List, Optional, Union +import torch +import torch.nn as nn +import torch.nn.functional as F + +import logging + +from tqdm import tqdm + +from library.device_utils import clean_memory_on_device +from library.safetensors_utils import MemoryEfficientSafeOpen +from library.utils import setup_logging + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + + +def calculate_fp8_maxval(exp_bits=4, mantissa_bits=3, sign_bits=1): + """ + Calculate the maximum representable value in FP8 format. + Default is E4M3 format (4-bit exponent, 3-bit mantissa, 1-bit sign). Only supports E4M3 and E5M2 with sign bit. + + Args: + exp_bits (int): Number of exponent bits + mantissa_bits (int): Number of mantissa bits + sign_bits (int): Number of sign bits (0 or 1) + + Returns: + float: Maximum value representable in FP8 format + """ + assert exp_bits + mantissa_bits + sign_bits == 8, "Total bits must be 8" + if exp_bits == 4 and mantissa_bits == 3 and sign_bits == 1: + return torch.finfo(torch.float8_e4m3fn).max + elif exp_bits == 5 and mantissa_bits == 2 and sign_bits == 1: + return torch.finfo(torch.float8_e5m2).max + else: + raise ValueError(f"Unsupported FP8 format: E{exp_bits}M{mantissa_bits} with sign_bits={sign_bits}") + + +# The following is a manual calculation method (wrong implementation for E5M2), kept for reference. +""" +# Calculate exponent bias +bias = 2 ** (exp_bits - 1) - 1 + +# Calculate maximum mantissa value +mantissa_max = 1.0 +for i in range(mantissa_bits - 1): + mantissa_max += 2 ** -(i + 1) + +# Calculate maximum value +max_value = mantissa_max * (2 ** (2**exp_bits - 1 - bias)) + +return max_value +""" + + +def quantize_fp8(tensor, scale, fp8_dtype, max_value, min_value): + """ + Quantize a tensor to FP8 format using PyTorch's native FP8 dtype support. + + Args: + tensor (torch.Tensor): Tensor to quantize + scale (float or torch.Tensor): Scale factor + fp8_dtype (torch.dtype): Target FP8 dtype (torch.float8_e4m3fn or torch.float8_e5m2) + max_value (float): Maximum representable value in FP8 + min_value (float): Minimum representable value in FP8 + + Returns: + torch.Tensor: Quantized tensor in FP8 format + """ + tensor = tensor.to(torch.float32) # ensure tensor is in float32 for division + + # Create scaled tensor + tensor = torch.div(tensor, scale).nan_to_num_(0.0) # handle NaN values, equivalent to nonzero_mask in previous function + + # Clamp tensor to range + tensor = tensor.clamp_(min=min_value, max=max_value) + + # Convert to FP8 dtype + tensor = tensor.to(fp8_dtype) + + return tensor + + +def optimize_state_dict_with_fp8( + state_dict: dict, + calc_device: Union[str, torch.device], + target_layer_keys: Optional[list[str]] = None, + exclude_layer_keys: Optional[list[str]] = None, + exp_bits: int = 4, + mantissa_bits: int = 3, + move_to_device: bool = False, + quantization_mode: str = "block", + block_size: Optional[int] = 64, +): + """ + Optimize Linear layer weights in a model's state dict to FP8 format. The state dict is modified in-place. + This function is a static version of load_safetensors_with_fp8_optimization without loading from files. + + Args: + state_dict (dict): State dict to optimize, replaced in-place + calc_device (str): Device to quantize tensors on + target_layer_keys (list, optional): Layer key patterns to target (None for all Linear layers) + exclude_layer_keys (list, optional): Layer key patterns to exclude + exp_bits (int): Number of exponent bits + mantissa_bits (int): Number of mantissa bits + move_to_device (bool): Move optimized tensors to the calculating device + + Returns: + dict: FP8 optimized state dict + """ + if exp_bits == 4 and mantissa_bits == 3: + fp8_dtype = torch.float8_e4m3fn + elif exp_bits == 5 and mantissa_bits == 2: + fp8_dtype = torch.float8_e5m2 + else: + raise ValueError(f"Unsupported FP8 format: E{exp_bits}M{mantissa_bits}") + + # Calculate FP8 max value + max_value = calculate_fp8_maxval(exp_bits, mantissa_bits) + min_value = -max_value # this function supports only signed FP8 + + # Create optimized state dict + optimized_count = 0 + + # Enumerate tarket keys + target_state_dict_keys = [] + for key in state_dict.keys(): + # Check if it's a weight key and matches target patterns + is_target = (target_layer_keys is None or any(pattern in key for pattern in target_layer_keys)) and key.endswith(".weight") + is_excluded = exclude_layer_keys is not None and any(pattern in key for pattern in exclude_layer_keys) + is_target = is_target and not is_excluded + + if is_target and isinstance(state_dict[key], torch.Tensor): + target_state_dict_keys.append(key) + + # Process each key + for key in tqdm(target_state_dict_keys): + value = state_dict[key] + + # Save original device and dtype + original_device = value.device + original_dtype = value.dtype + + # Move to calculation device + if calc_device is not None: + value = value.to(calc_device) + + quantized_weight, scale_tensor = quantize_weight(key, value, fp8_dtype, max_value, min_value, quantization_mode, block_size) + + # Add to state dict using original key for weight and new key for scale + fp8_key = key # Maintain original key + scale_key = key.replace(".weight", ".scale_weight") + + if not move_to_device: + quantized_weight = quantized_weight.to(original_device) + + # keep scale shape: [1] or [out,1] or [out, num_blocks, 1]. We can determine the quantization mode from the shape of scale_weight in the patched model. + scale_tensor = scale_tensor.to(dtype=original_dtype, device=quantized_weight.device) + + state_dict[fp8_key] = quantized_weight + state_dict[scale_key] = scale_tensor + + optimized_count += 1 + + if calc_device is not None: # optimized_count % 10 == 0 and + # free memory on calculation device + clean_memory_on_device(calc_device) + + logger.info(f"Number of optimized Linear layers: {optimized_count}") + return state_dict + + +def quantize_weight( + key: str, + tensor: torch.Tensor, + fp8_dtype: torch.dtype, + max_value: float, + min_value: float, + quantization_mode: str = "block", + block_size: int = 64, +): + original_shape = tensor.shape + + # Determine quantization mode + if quantization_mode == "block": + if tensor.ndim != 2: + quantization_mode = "tensor" # fallback to per-tensor + else: + out_features, in_features = tensor.shape + if in_features % block_size != 0: + quantization_mode = "channel" # fallback to per-channel + logger.warning( + f"Layer {key} with shape {tensor.shape} is not divisible by block_size {block_size}, fallback to per-channel quantization." + ) + else: + num_blocks = in_features // block_size + tensor = tensor.contiguous().view(out_features, num_blocks, block_size) # [out, num_blocks, block_size] + elif quantization_mode == "channel": + if tensor.ndim != 2: + quantization_mode = "tensor" # fallback to per-tensor + + # Calculate scale factor (per-tensor or per-output-channel with percentile or max) + # value shape is expected to be [out_features, in_features] for Linear weights + if quantization_mode == "channel" or quantization_mode == "block": + # row-wise percentile to avoid being dominated by outliers + # result shape: [out_features, 1] or [out_features, num_blocks, 1] + scale_dim = 1 if quantization_mode == "channel" else 2 + abs_w = torch.abs(tensor) + + # shape: [out_features, 1] or [out_features, num_blocks, 1] + row_max = torch.max(abs_w, dim=scale_dim, keepdim=True).values + scale = row_max / max_value + + else: + # per-tensor + tensor_max = torch.max(torch.abs(tensor).view(-1)) + scale = tensor_max / max_value + + # numerical safety + scale = torch.clamp(scale, min=1e-8) + scale = scale.to(torch.float32) # ensure scale is in float32 for division + + # Quantize weight to FP8 (scale can be scalar or [out,1], broadcasting works) + quantized_weight = quantize_fp8(tensor, scale, fp8_dtype, max_value, min_value) + + # If block-wise, restore original shape + if quantization_mode == "block": + quantized_weight = quantized_weight.view(original_shape) # restore to original shape [out, in] + + return quantized_weight, scale + + +def load_safetensors_with_fp8_optimization( + model_files: List[str], + calc_device: Union[str, torch.device], + target_layer_keys=None, + exclude_layer_keys=None, + exp_bits=4, + mantissa_bits=3, + move_to_device=False, + weight_hook=None, + quantization_mode: str = "block", + block_size: Optional[int] = 64, +) -> dict: + """ + Load weight tensors from safetensors files and merge LoRA weights into the state dict with explicit FP8 optimization. + + Args: + model_files (list[str]): List of model files to load + calc_device (str or torch.device): Device to quantize tensors on + target_layer_keys (list, optional): Layer key patterns to target for optimization (None for all Linear layers) + exclude_layer_keys (list, optional): Layer key patterns to exclude from optimization + exp_bits (int): Number of exponent bits + mantissa_bits (int): Number of mantissa bits + move_to_device (bool): Move optimized tensors to the calculating device + weight_hook (callable, optional): Function to apply to each weight tensor before optimization + quantization_mode (str): Quantization mode, "tensor", "channel", or "block" + block_size (int, optional): Block size for block-wise quantization (used if quantization_mode is "block") + + Returns: + dict: FP8 optimized state dict + """ + if exp_bits == 4 and mantissa_bits == 3: + fp8_dtype = torch.float8_e4m3fn + elif exp_bits == 5 and mantissa_bits == 2: + fp8_dtype = torch.float8_e5m2 + else: + raise ValueError(f"Unsupported FP8 format: E{exp_bits}M{mantissa_bits}") + + # Calculate FP8 max value + max_value = calculate_fp8_maxval(exp_bits, mantissa_bits) + min_value = -max_value # this function supports only signed FP8 + + # Define function to determine if a key is a target key. target means fp8 optimization, not for weight hook. + def is_target_key(key): + # Check if weight key matches target patterns and does not match exclude patterns + is_target = (target_layer_keys is None or any(pattern in key for pattern in target_layer_keys)) and key.endswith(".weight") + is_excluded = exclude_layer_keys is not None and any(pattern in key for pattern in exclude_layer_keys) + return is_target and not is_excluded + + # Create optimized state dict + optimized_count = 0 + + # Process each file + state_dict = {} + for model_file in model_files: + with MemoryEfficientSafeOpen(model_file) as f: + keys = f.keys() + for key in tqdm(keys, desc=f"Loading {os.path.basename(model_file)}", unit="key"): + value = f.get_tensor(key) + + # Save original device + original_device = value.device # usually cpu + + if weight_hook is not None: + # Apply weight hook if provided + value = weight_hook(key, value, keep_on_calc_device=(calc_device is not None)) + + if not is_target_key(key): + target_device = calc_device if (calc_device is not None and move_to_device) else original_device + value = value.to(target_device) + state_dict[key] = value + continue + + # Move to calculation device + if calc_device is not None: + value = value.to(calc_device) + + original_dtype = value.dtype + quantized_weight, scale_tensor = quantize_weight( + key, value, fp8_dtype, max_value, min_value, quantization_mode, block_size + ) + + # Add to state dict using original key for weight and new key for scale + fp8_key = key # Maintain original key + scale_key = key.replace(".weight", ".scale_weight") + assert fp8_key != scale_key, "FP8 key and scale key must be different" + + if not move_to_device: + quantized_weight = quantized_weight.to(original_device) + + # keep scale shape: [1] or [out,1] or [out, num_blocks, 1]. We can determine the quantization mode from the shape of scale_weight in the patched model. + scale_tensor = scale_tensor.to(dtype=original_dtype, device=quantized_weight.device) + + state_dict[fp8_key] = quantized_weight + state_dict[scale_key] = scale_tensor + + optimized_count += 1 + + if calc_device is not None and optimized_count % 10 == 0: + # free memory on calculation device + clean_memory_on_device(calc_device) + + logger.info(f"Number of optimized Linear layers: {optimized_count}") + return state_dict + + +def fp8_linear_forward_patch(self: nn.Linear, x, use_scaled_mm=False, max_value=None): + """ + Patched forward method for Linear layers with FP8 weights. + + Args: + self: Linear layer instance + x (torch.Tensor): Input tensor + use_scaled_mm (bool): Use scaled_mm for FP8 Linear layers, requires SM 8.9+ (RTX 40 series) + max_value (float): Maximum value for FP8 quantization. If None, no quantization is applied for input tensor. + + Returns: + torch.Tensor: Result of linear transformation + """ + if use_scaled_mm: + # **not tested** + # _scaled_mm only works for per-tensor scale for now (per-channel scale does not work in certain cases) + if self.scale_weight.ndim != 1: + raise ValueError("scaled_mm only supports per-tensor scale_weight for now.") + + input_dtype = x.dtype + original_weight_dtype = self.scale_weight.dtype + target_dtype = self.weight.dtype + # assert x.ndim == 3, "Input tensor must be 3D (batch_size, seq_len, hidden_dim)" + + if max_value is None: + # no input quantization + scale_x = torch.tensor(1.0, dtype=torch.float32, device=x.device) + else: + # calculate scale factor for input tensor + scale_x = (torch.max(torch.abs(x.flatten())) / max_value).to(torch.float32) + + # quantize input tensor to FP8: this seems to consume a lot of memory + fp8_max_value = torch.finfo(target_dtype).max + fp8_min_value = torch.finfo(target_dtype).min + x = quantize_fp8(x, scale_x, target_dtype, fp8_max_value, fp8_min_value) + + original_shape = x.shape + x = x.reshape(-1, x.shape[-1]).to(target_dtype) + + weight = self.weight.t() + scale_weight = self.scale_weight.to(torch.float32) + + if self.bias is not None: + # float32 is not supported with bias in scaled_mm + o = torch._scaled_mm(x, weight, out_dtype=original_weight_dtype, bias=self.bias, scale_a=scale_x, scale_b=scale_weight) + else: + o = torch._scaled_mm(x, weight, out_dtype=input_dtype, scale_a=scale_x, scale_b=scale_weight) + + o = o.reshape(original_shape[0], original_shape[1], -1) if x.ndim == 3 else o.reshape(original_shape[0], -1) + return o.to(input_dtype) + + else: + # Dequantize the weight + original_dtype = self.scale_weight.dtype + if self.scale_weight.ndim < 3: + # per-tensor or per-channel quantization, we can broadcast + dequantized_weight = self.weight.to(original_dtype) * self.scale_weight + else: + # block-wise quantization, need to reshape weight to match scale shape for broadcasting + out_features, num_blocks, _ = self.scale_weight.shape + dequantized_weight = self.weight.to(original_dtype).contiguous().view(out_features, num_blocks, -1) + dequantized_weight = dequantized_weight * self.scale_weight + dequantized_weight = dequantized_weight.view(self.weight.shape) + + # Perform linear transformation + if self.bias is not None: + output = F.linear(x, dequantized_weight, self.bias) + else: + output = F.linear(x, dequantized_weight) + + return output + + +def apply_fp8_monkey_patch(model, optimized_state_dict, use_scaled_mm=False): + """ + Apply monkey patching to a model using FP8 optimized state dict. + + Args: + model (nn.Module): Model instance to patch + optimized_state_dict (dict): FP8 optimized state dict + use_scaled_mm (bool): Use scaled_mm for FP8 Linear layers, requires SM 8.9+ (RTX 40 series) + + Returns: + nn.Module: The patched model (same instance, modified in-place) + """ + # # Calculate FP8 float8_e5m2 max value + # max_value = calculate_fp8_maxval(5, 2) + max_value = None # do not quantize input tensor + + # Find all scale keys to identify FP8-optimized layers + scale_keys = [k for k in optimized_state_dict.keys() if k.endswith(".scale_weight")] + + # Enumerate patched layers + patched_module_paths = set() + scale_shape_info = {} + for scale_key in scale_keys: + # Extract module path from scale key (remove .scale_weight) + module_path = scale_key.rsplit(".scale_weight", 1)[0] + patched_module_paths.add(module_path) + + # Store scale shape information + scale_shape_info[module_path] = optimized_state_dict[scale_key].shape + + patched_count = 0 + + # Apply monkey patch to each layer with FP8 weights + for name, module in model.named_modules(): + # Check if this module has a corresponding scale_weight + has_scale = name in patched_module_paths + + # Apply patch if it's a Linear layer with FP8 scale + if isinstance(module, nn.Linear) and has_scale: + # register the scale_weight as a buffer to load the state_dict + # module.register_buffer("scale_weight", torch.tensor(1.0, dtype=module.weight.dtype)) + scale_shape = scale_shape_info[name] + module.register_buffer("scale_weight", torch.ones(scale_shape, dtype=module.weight.dtype)) + + # Create a new forward method with the patched version. + def new_forward(self, x): + return fp8_linear_forward_patch(self, x, use_scaled_mm, max_value) + + # Bind method to module + module.forward = new_forward.__get__(module, type(module)) + + patched_count += 1 + + logger.info(f"Number of monkey-patched Linear layers: {patched_count}") + return model diff --git a/library/hunyuan_image_models.py b/library/hunyuan_image_models.py new file mode 100644 index 000000000..fc320dfc1 --- /dev/null +++ b/library/hunyuan_image_models.py @@ -0,0 +1,489 @@ +# Original work: https://github.com/Tencent-Hunyuan/HunyuanImage-2.1 +# Re-implemented for license compliance for sd-scripts. + +from typing import Dict, Optional, Tuple, Union + +import torch +import torch.nn as nn +from accelerate import init_empty_weights + +from library import custom_offloading_utils +from library.attention import AttentionParams +from library.fp8_optimization_utils import apply_fp8_monkey_patch +from library.lora_utils import load_safetensors_with_lora_and_fp8 +from library.utils import setup_logging + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + + +from library.hunyuan_image_modules import ( + SingleTokenRefiner, + ByT5Mapper, + PatchEmbed2D, + TimestepEmbedder, + MMDoubleStreamBlock, + MMSingleStreamBlock, + FinalLayer, +) +from library.hunyuan_image_utils import get_nd_rotary_pos_embed + +FP8_OPTIMIZATION_TARGET_KEYS = ["double_blocks", "single_blocks"] +# FP8_OPTIMIZATION_EXCLUDE_KEYS = ["norm", "_mod", "_emb"] # , "modulation" +FP8_OPTIMIZATION_EXCLUDE_KEYS = ["norm", "_emb"] # , "modulation", "_mod" + +# full exclude 24.2GB +# norm and _emb 19.7GB +# fp8 cast 19.7GB + + +# region DiT Model +class HYImageDiffusionTransformer(nn.Module): + """ + HunyuanImage-2.1 Diffusion Transformer. + + A multimodal transformer for image generation with text conditioning, + featuring separate double-stream and single-stream processing blocks. + + Args: + attn_mode: Attention implementation mode ("torch" or "sageattn"). + """ + + def __init__(self, attn_mode: str = "torch", split_attn: bool = False): + super().__init__() + + # Fixed architecture parameters for HunyuanImage-2.1 + self.patch_size = [1, 1] # 1x1 patch size (no spatial downsampling) + self.in_channels = 64 # Input latent channels + self.out_channels = 64 # Output latent channels + self.unpatchify_channels = self.out_channels + self.guidance_embed = False # Guidance embedding disabled + self.rope_dim_list = [64, 64] # RoPE dimensions for 2D positional encoding + self.rope_theta = 256 # RoPE frequency scaling + self.use_attention_mask = True + self.text_projection = "single_refiner" + self.hidden_size = 3584 # Model dimension + self.heads_num = 28 # Number of attention heads + + # Architecture configuration + mm_double_blocks_depth = 20 # Double-stream transformer blocks + mm_single_blocks_depth = 40 # Single-stream transformer blocks + mlp_width_ratio = 4 # MLP expansion ratio + text_states_dim = 3584 # Text encoder output dimension + guidance_embed = False # No guidance embedding + + # Layer configuration + mlp_act_type: str = "gelu_tanh" # MLP activation function + qkv_bias: bool = True # Use bias in QKV projections + qk_norm: bool = True # Apply QK normalization + qk_norm_type: str = "rms" # RMS normalization type + + self.attn_mode = attn_mode + self.split_attn = split_attn + + # ByT5 character-level text encoder mapping + self.byt5_in = ByT5Mapper(in_dim=1472, out_dim=2048, hidden_dim=2048, out_dim1=self.hidden_size, use_residual=False) + + # Image latent patch embedding + self.img_in = PatchEmbed2D(self.patch_size, self.in_channels, self.hidden_size) + + # Text token refinement with cross-attention + self.txt_in = SingleTokenRefiner(text_states_dim, self.hidden_size, self.heads_num, depth=2) + + # Timestep embedding for diffusion process + self.time_in = TimestepEmbedder(self.hidden_size, nn.SiLU) + + # MeanFlow not supported in this implementation + self.time_r_in = None + + # Guidance embedding (disabled for non-distilled model) + self.guidance_in = TimestepEmbedder(self.hidden_size, nn.SiLU) if guidance_embed else None + + # Double-stream blocks: separate image and text processing + self.double_blocks = nn.ModuleList( + [ + MMDoubleStreamBlock( + self.hidden_size, + self.heads_num, + mlp_width_ratio=mlp_width_ratio, + mlp_act_type=mlp_act_type, + qk_norm=qk_norm, + qk_norm_type=qk_norm_type, + qkv_bias=qkv_bias, + ) + for _ in range(mm_double_blocks_depth) + ] + ) + + # Single-stream blocks: joint processing of concatenated features + self.single_blocks = nn.ModuleList( + [ + MMSingleStreamBlock( + self.hidden_size, + self.heads_num, + mlp_width_ratio=mlp_width_ratio, + mlp_act_type=mlp_act_type, + qk_norm=qk_norm, + qk_norm_type=qk_norm_type, + ) + for _ in range(mm_single_blocks_depth) + ] + ) + + self.final_layer = FinalLayer(self.hidden_size, self.patch_size, self.out_channels, nn.SiLU) + + self.gradient_checkpointing = False + self.cpu_offload_checkpointing = False + self.blocks_to_swap = None + + self.offloader_double = None + self.offloader_single = None + self.num_double_blocks = len(self.double_blocks) + self.num_single_blocks = len(self.single_blocks) + + @property + def device(self): + return next(self.parameters()).device + + @property + def dtype(self): + return next(self.parameters()).dtype + + def enable_gradient_checkpointing(self, cpu_offload: bool = False): + self.gradient_checkpointing = True + self.cpu_offload_checkpointing = cpu_offload + + for block in self.double_blocks + self.single_blocks: + block.enable_gradient_checkpointing(cpu_offload=cpu_offload) + + print(f"HunyuanImage-2.1: Gradient checkpointing enabled. CPU offload: {cpu_offload}") + + def disable_gradient_checkpointing(self): + self.gradient_checkpointing = False + self.cpu_offload_checkpointing = False + + for block in self.double_blocks + self.single_blocks: + block.disable_gradient_checkpointing() + + print("HunyuanImage-2.1: Gradient checkpointing disabled.") + + def enable_block_swap(self, num_blocks: int, device: torch.device, supports_backward: bool = False): + self.blocks_to_swap = num_blocks + double_blocks_to_swap = num_blocks // 2 + single_blocks_to_swap = (num_blocks - double_blocks_to_swap) * 2 + + assert double_blocks_to_swap <= self.num_double_blocks - 2 and single_blocks_to_swap <= self.num_single_blocks - 2, ( + f"Cannot swap more than {self.num_double_blocks - 2} double blocks and {self.num_single_blocks - 2} single blocks. " + f"Requested {double_blocks_to_swap} double blocks and {single_blocks_to_swap} single blocks." + ) + + self.offloader_double = custom_offloading_utils.ModelOffloader( + self.double_blocks, double_blocks_to_swap, device, supports_backward=supports_backward + ) + self.offloader_single = custom_offloading_utils.ModelOffloader( + self.single_blocks, single_blocks_to_swap, device, supports_backward=supports_backward + ) + # , debug=True + print( + f"HunyuanImage-2.1: Block swap enabled. Swapping {num_blocks} blocks, double blocks: {double_blocks_to_swap}, single blocks: {single_blocks_to_swap}." + ) + + def switch_block_swap_for_inference(self): + if self.blocks_to_swap: + self.offloader_double.set_forward_only(True) + self.offloader_single.set_forward_only(True) + self.prepare_block_swap_before_forward() + print(f"HunyuanImage-2.1: Block swap set to forward only.") + + def switch_block_swap_for_training(self): + if self.blocks_to_swap: + self.offloader_double.set_forward_only(False) + self.offloader_single.set_forward_only(False) + self.prepare_block_swap_before_forward() + print(f"HunyuanImage-2.1: Block swap set to forward and backward.") + + def move_to_device_except_swap_blocks(self, device: torch.device): + # assume model is on cpu. do not move blocks to device to reduce temporary memory usage + if self.blocks_to_swap: + save_double_blocks = self.double_blocks + save_single_blocks = self.single_blocks + self.double_blocks = nn.ModuleList() + self.single_blocks = nn.ModuleList() + + self.to(device) + + if self.blocks_to_swap: + self.double_blocks = save_double_blocks + self.single_blocks = save_single_blocks + + def prepare_block_swap_before_forward(self): + if self.blocks_to_swap is None or self.blocks_to_swap == 0: + return + self.offloader_double.prepare_block_devices_before_forward(self.double_blocks) + self.offloader_single.prepare_block_devices_before_forward(self.single_blocks) + + def get_rotary_pos_embed(self, rope_sizes): + """ + Generate 2D rotary position embeddings for image tokens. + + Args: + rope_sizes: Tuple of (height, width) for spatial dimensions. + + Returns: + Tuple of (freqs_cos, freqs_sin) tensors for rotary position encoding. + """ + freqs_cos, freqs_sin = get_nd_rotary_pos_embed(self.rope_dim_list, rope_sizes, theta=self.rope_theta) + return freqs_cos, freqs_sin + + def reorder_txt_token( + self, byt5_txt: torch.Tensor, txt: torch.Tensor, byt5_text_mask: torch.Tensor, text_mask: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor, list[int]]: + """ + Combine and reorder ByT5 character-level and word-level text embeddings. + + Concatenates valid tokens from both encoders and creates appropriate masks. + + Args: + byt5_txt: ByT5 character-level embeddings [B, L1, D]. + txt: Word-level text embeddings [B, L2, D]. + byt5_text_mask: Valid token mask for ByT5 [B, L1]. + text_mask: Valid token mask for word tokens [B, L2]. + + Returns: + Tuple of (reordered_embeddings, combined_mask, sequence_lengths). + """ + # Process each batch element separately to handle variable sequence lengths + + reorder_txt = [] + reorder_mask = [] + + txt_lens = [] + for i in range(text_mask.shape[0]): + byt5_text_mask_i = byt5_text_mask[i].bool() + text_mask_i = text_mask[i].bool() + byt5_text_length = byt5_text_mask_i.sum() + text_length = text_mask_i.sum() + assert byt5_text_length == byt5_text_mask_i[:byt5_text_length].sum() + assert text_length == text_mask_i[:text_length].sum() + + byt5_txt_i = byt5_txt[i] + txt_i = txt[i] + reorder_txt_i = torch.cat( + [byt5_txt_i[:byt5_text_length], txt_i[:text_length], byt5_txt_i[byt5_text_length:], txt_i[text_length:]], dim=0 + ) + + reorder_mask_i = torch.zeros( + byt5_text_mask_i.shape[0] + text_mask_i.shape[0], dtype=torch.bool, device=byt5_text_mask_i.device + ) + reorder_mask_i[: byt5_text_length + text_length] = True + + reorder_txt.append(reorder_txt_i) + reorder_mask.append(reorder_mask_i) + txt_lens.append(byt5_text_length + text_length) + + reorder_txt = torch.stack(reorder_txt) + reorder_mask = torch.stack(reorder_mask).to(dtype=torch.int64) + + return reorder_txt, reorder_mask, txt_lens + + def forward( + self, + hidden_states: torch.Tensor, + timestep: torch.LongTensor, + text_states: torch.Tensor, + encoder_attention_mask: torch.Tensor, + byt5_text_states: Optional[torch.Tensor] = None, + byt5_text_mask: Optional[torch.Tensor] = None, + rotary_pos_emb_cache: Optional[Dict[Tuple[int, int], Tuple[torch.Tensor, torch.Tensor]]] = None, + ) -> torch.Tensor: + """ + Forward pass through the HunyuanImage diffusion transformer. + + Args: + hidden_states: Input image latents [B, C, H, W]. + timestep: Diffusion timestep [B]. + text_states: Word-level text embeddings [B, L, D]. + encoder_attention_mask: Text attention mask [B, L]. + byt5_text_states: ByT5 character-level embeddings [B, L_byt5, D_byt5]. + byt5_text_mask: ByT5 attention mask [B, L_byt5]. + + Returns: + Tuple of (denoised_image, spatial_shape). + """ + img = x = hidden_states + text_mask = encoder_attention_mask + t = timestep + txt = text_states + + # Calculate spatial dimensions for rotary position embeddings + _, _, oh, ow = x.shape + th, tw = oh, ow # Height and width (patch_size=[1,1] means no spatial downsampling) + if rotary_pos_emb_cache is not None: + if (th, tw) in rotary_pos_emb_cache: + freqs_cis = rotary_pos_emb_cache[(th, tw)] + freqs_cis = (freqs_cis[0].to(img.device), freqs_cis[1].to(img.device)) + else: + freqs_cis = self.get_rotary_pos_embed((th, tw)) + rotary_pos_emb_cache[(th, tw)] = (freqs_cis[0].cpu(), freqs_cis[1].cpu()) + else: + freqs_cis = self.get_rotary_pos_embed((th, tw)) + + # Reshape image latents to sequence format: [B, C, H, W] -> [B, H*W, C] + img = self.img_in(img) + + # Generate timestep conditioning vector + vec = self.time_in(t) + + # MeanFlow and guidance embedding not used in this configuration + + # Process text tokens through refinement layers + txt_attn_params = AttentionParams.create_attention_params_from_mask(self.attn_mode, self.split_attn, 0, text_mask) + txt = self.txt_in(txt, t, txt_attn_params) + + # Integrate character-level ByT5 features with word-level tokens + # Use variable length sequences with sequence lengths + byt5_txt = self.byt5_in(byt5_text_states) + txt, text_mask, txt_lens = self.reorder_txt_token(byt5_txt, txt, byt5_text_mask, text_mask) + + # Trim sequences to maximum length in the batch + img_seq_len = img.shape[1] + max_txt_len = max(txt_lens) + txt = txt[:, :max_txt_len, :] + text_mask = text_mask[:, :max_txt_len] + + attn_params = AttentionParams.create_attention_params_from_mask(self.attn_mode, self.split_attn, img_seq_len, text_mask) + + input_device = img.device + + # Process through double-stream blocks (separate image/text attention) + for index, block in enumerate(self.double_blocks): + if self.blocks_to_swap: + self.offloader_double.wait_for_block(index) + img, txt = block(img, txt, vec, freqs_cis, attn_params) + if self.blocks_to_swap: + self.offloader_double.submit_move_blocks(self.double_blocks, index) + + # Concatenate image and text tokens for joint processing + x = torch.cat((img, txt), 1) + + # Process through single-stream blocks (joint attention) + for index, block in enumerate(self.single_blocks): + if self.blocks_to_swap: + self.offloader_single.wait_for_block(index) + x = block(x, vec, freqs_cis, attn_params) + if self.blocks_to_swap: + self.offloader_single.submit_move_blocks(self.single_blocks, index) + + x = x.to(input_device) + vec = vec.to(input_device) + + img = x[:, :img_seq_len, ...] + del x + + # Apply final projection to output space + img = self.final_layer(img, vec) + del vec + + # Reshape from sequence to spatial format: [B, L, C] -> [B, C, H, W] + img = self.unpatchify_2d(img, th, tw) + return img + + def unpatchify_2d(self, x, h, w): + """ + Convert sequence format back to spatial image format. + + Args: + x: Input tensor [B, H*W, C]. + h: Height dimension. + w: Width dimension. + + Returns: + Spatial tensor [B, C, H, W]. + """ + c = self.unpatchify_channels + + x = x.reshape(shape=(x.shape[0], h, w, c)) + imgs = x.permute(0, 3, 1, 2) + return imgs + + +# endregion + +# region Model Utils + + +def create_model(attn_mode: str, split_attn: bool, dtype: Optional[torch.dtype]) -> HYImageDiffusionTransformer: + with init_empty_weights(): + model = HYImageDiffusionTransformer(attn_mode=attn_mode, split_attn=split_attn) + if dtype is not None: + model.to(dtype) + return model + + +def load_hunyuan_image_model( + device: Union[str, torch.device], + dit_path: str, + attn_mode: str, + split_attn: bool, + loading_device: Union[str, torch.device], + dit_weight_dtype: Optional[torch.dtype], + fp8_scaled: bool = False, + lora_weights_list: Optional[Dict[str, torch.Tensor]] = None, + lora_multipliers: Optional[list[float]] = None, +) -> HYImageDiffusionTransformer: + """ + Load a HunyuanImage model from the specified checkpoint. + + Args: + device (Union[str, torch.device]): Device for optimization or merging + dit_path (str): Path to the DiT model checkpoint. + attn_mode (str): Attention mode to use, e.g., "torch", "flash", etc. + split_attn (bool): Whether to use split attention. + loading_device (Union[str, torch.device]): Device to load the model weights on. + dit_weight_dtype (Optional[torch.dtype]): Data type of the DiT weights. + If None, it will be loaded as is (same as the state_dict) or scaled for fp8. if not None, model weights will be casted to this dtype. + fp8_scaled (bool): Whether to use fp8 scaling for the model weights. + lora_weights_list (Optional[Dict[str, torch.Tensor]]): LoRA weights to apply, if any. + lora_multipliers (Optional[List[float]]): LoRA multipliers for the weights, if any. + """ + # dit_weight_dtype is None for fp8_scaled + assert (not fp8_scaled and dit_weight_dtype is not None) or (fp8_scaled and dit_weight_dtype is None) + + device = torch.device(device) + loading_device = torch.device(loading_device) + + model = create_model(attn_mode, split_attn, dit_weight_dtype) + + # load model weights with dynamic fp8 optimization and LoRA merging if needed + logger.info(f"Loading DiT model from {dit_path}, device={loading_device}") + + sd = load_safetensors_with_lora_and_fp8( + model_files=dit_path, + lora_weights_list=lora_weights_list, + lora_multipliers=lora_multipliers, + fp8_optimization=fp8_scaled, + calc_device=device, + move_to_device=(loading_device == device), + dit_weight_dtype=dit_weight_dtype, + target_keys=FP8_OPTIMIZATION_TARGET_KEYS, + exclude_keys=FP8_OPTIMIZATION_EXCLUDE_KEYS, + ) + + if fp8_scaled: + apply_fp8_monkey_patch(model, sd, use_scaled_mm=False) + + if loading_device.type != "cpu": + # make sure all the model weights are on the loading_device + logger.info(f"Moving weights to {loading_device}") + for key in sd.keys(): + sd[key] = sd[key].to(loading_device) + + info = model.load_state_dict(sd, strict=True, assign=True) + logger.info(f"Loaded DiT model from {dit_path}, info={info}") + + return model + + +# endregion diff --git a/library/hunyuan_image_modules.py b/library/hunyuan_image_modules.py new file mode 100644 index 000000000..1953a783e --- /dev/null +++ b/library/hunyuan_image_modules.py @@ -0,0 +1,863 @@ +# Original work: https://github.com/Tencent-Hunyuan/HunyuanImage-2.1 +# Re-implemented for license compliance for sd-scripts. + +from typing import Tuple, Callable +import torch +import torch.nn as nn +from einops import rearrange + +from library import custom_offloading_utils +from library.attention import AttentionParams, attention +from library.hunyuan_image_utils import timestep_embedding, apply_rotary_emb, _to_tuple, apply_gate, modulate +from library.attention import attention + +# region Modules + + +class ByT5Mapper(nn.Module): + """ + Maps ByT5 character-level encoder outputs to transformer hidden space. + + Applies layer normalization, two MLP layers with GELU activation, + and optional residual connection. + + Args: + in_dim: Input dimension from ByT5 encoder (1472 for ByT5-large). + out_dim: Intermediate dimension after first projection. + hidden_dim: Hidden dimension for MLP layer. + out_dim1: Final output dimension matching transformer hidden size. + use_residual: Whether to add residual connection (requires in_dim == out_dim). + """ + + def __init__(self, in_dim, out_dim, hidden_dim, out_dim1, use_residual=True): + super().__init__() + if use_residual: + assert in_dim == out_dim + self.layernorm = nn.LayerNorm(in_dim) + self.fc1 = nn.Linear(in_dim, hidden_dim) + self.fc2 = nn.Linear(hidden_dim, out_dim) + self.fc3 = nn.Linear(out_dim, out_dim1) + self.use_residual = use_residual + self.act_fn = nn.GELU() + + def forward(self, x): + """ + Transform ByT5 embeddings to transformer space. + + Args: + x: Input ByT5 embeddings [..., in_dim]. + + Returns: + Transformed embeddings [..., out_dim1]. + """ + residual = x if self.use_residual else None + x = self.layernorm(x) + x = self.fc1(x) + x = self.act_fn(x) + x = self.fc2(x) + x = self.act_fn(x) + x = self.fc3(x) + if self.use_residual: + x = x + residual + return x + + +class PatchEmbed2D(nn.Module): + """ + 2D patch embedding layer for converting image latents to transformer tokens. + + Uses 2D convolution to project image patches to embedding space. + For HunyuanImage-2.1, patch_size=[1,1] means no spatial downsampling. + + Args: + patch_size: Spatial size of patches (int or tuple). + in_chans: Number of input channels. + embed_dim: Output embedding dimension. + """ + + def __init__(self, patch_size=16, in_chans=3, embed_dim=768): + super().__init__() + self.patch_size = tuple(patch_size) + + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=self.patch_size, stride=self.patch_size, bias=True) + self.norm = nn.Identity() # No normalization layer used + + def forward(self, x): + x = self.proj(x) + x = x.flatten(2).transpose(1, 2) + x = self.norm(x) + return x + + +class TimestepEmbedder(nn.Module): + """ + Embeds scalar diffusion timesteps into vector representations. + + Uses sinusoidal encoding followed by a two-layer MLP. + + Args: + hidden_size: Output embedding dimension. + act_layer: Activation function class (e.g., nn.SiLU). + frequency_embedding_size: Dimension of sinusoidal encoding. + max_period: Maximum period for sinusoidal frequencies. + out_size: Output dimension (defaults to hidden_size). + """ + + def __init__(self, hidden_size, act_layer, frequency_embedding_size=256, max_period=10000, out_size=None): + super().__init__() + self.frequency_embedding_size = frequency_embedding_size + self.max_period = max_period + if out_size is None: + out_size = hidden_size + + self.mlp = nn.Sequential( + nn.Linear(frequency_embedding_size, hidden_size, bias=True), act_layer(), nn.Linear(hidden_size, out_size, bias=True) + ) + + def forward(self, t): + t_freq = timestep_embedding(t, self.frequency_embedding_size, self.max_period).type(self.mlp[0].weight.dtype) + return self.mlp(t_freq) + + +class TextProjection(nn.Module): + """ + Projects text embeddings through a two-layer MLP. + + Used for context-aware representation computation in token refinement. + + Args: + in_channels: Input feature dimension. + hidden_size: Hidden and output dimension. + act_layer: Activation function class. + """ + + def __init__(self, in_channels, hidden_size, act_layer): + super().__init__() + self.linear_1 = nn.Linear(in_features=in_channels, out_features=hidden_size, bias=True) + self.act_1 = act_layer() + self.linear_2 = nn.Linear(in_features=hidden_size, out_features=hidden_size, bias=True) + + def forward(self, caption): + hidden_states = self.linear_1(caption) + hidden_states = self.act_1(hidden_states) + hidden_states = self.linear_2(hidden_states) + return hidden_states + + +class MLP(nn.Module): + """ + Multi-layer perceptron with configurable activation and normalization. + + Standard two-layer MLP with optional dropout and intermediate normalization. + + Args: + in_channels: Input feature dimension. + hidden_channels: Hidden layer dimension (defaults to in_channels). + out_features: Output dimension (defaults to in_channels). + act_layer: Activation function class. + norm_layer: Optional normalization layer class. + bias: Whether to use bias (can be bool or tuple for each layer). + drop: Dropout rate (can be float or tuple for each layer). + use_conv: Whether to use convolution instead of linear (not supported). + """ + + def __init__( + self, + in_channels, + hidden_channels=None, + out_features=None, + act_layer=nn.GELU, + norm_layer=None, + bias=True, + drop=0.0, + use_conv=False, + ): + super().__init__() + assert not use_conv, "Convolutional MLP not supported in this implementation." + + out_features = out_features or in_channels + hidden_channels = hidden_channels or in_channels + bias = _to_tuple(bias, 2) + drop_probs = _to_tuple(drop, 2) + + self.fc1 = nn.Linear(in_channels, hidden_channels, bias=bias[0]) + self.act = act_layer() + self.drop1 = nn.Dropout(drop_probs[0]) + self.norm = norm_layer(hidden_channels) if norm_layer is not None else nn.Identity() + self.fc2 = nn.Linear(hidden_channels, out_features, bias=bias[1]) + self.drop2 = nn.Dropout(drop_probs[1]) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop1(x) + x = self.norm(x) + x = self.fc2(x) + x = self.drop2(x) + return x + + +class IndividualTokenRefinerBlock(nn.Module): + """ + Single transformer block for individual token refinement. + + Applies self-attention and MLP with adaptive layer normalization (AdaLN) + conditioned on timestep and context information. + + Args: + hidden_size: Model dimension. + heads_num: Number of attention heads. + mlp_width_ratio: MLP expansion ratio. + mlp_drop_rate: MLP dropout rate. + act_type: Activation function (only "silu" supported). + qk_norm: QK normalization flag (must be False). + qk_norm_type: QK normalization type (only "layer" supported). + qkv_bias: Use bias in QKV projections. + """ + + def __init__( + self, + hidden_size: int, + heads_num: int, + mlp_width_ratio: float = 4.0, + mlp_drop_rate: float = 0.0, + act_type: str = "silu", + qk_norm: bool = False, + qk_norm_type: str = "layer", + qkv_bias: bool = True, + ): + super().__init__() + assert qk_norm_type == "layer", "Only layer normalization supported for QK norm." + assert act_type == "silu", "Only SiLU activation supported." + assert not qk_norm, "QK normalization must be disabled." + + self.heads_num = heads_num + mlp_hidden_dim = int(hidden_size * mlp_width_ratio) + + self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6) + self.self_attn_qkv = nn.Linear(hidden_size, hidden_size * 3, bias=qkv_bias) + + self.self_attn_q_norm = nn.Identity() + self.self_attn_k_norm = nn.Identity() + self.self_attn_proj = nn.Linear(hidden_size, hidden_size, bias=qkv_bias) + + self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6) + self.mlp = MLP(in_channels=hidden_size, hidden_channels=mlp_hidden_dim, act_layer=nn.SiLU, drop=mlp_drop_rate) + + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + nn.Linear(hidden_size, 2 * hidden_size, bias=True), + ) + + def forward(self, x: torch.Tensor, c: torch.Tensor, attn_params: AttentionParams) -> torch.Tensor: + """ + Apply self-attention and MLP with adaptive conditioning. + + Args: + x: Input token embeddings [B, L, C]. + c: Combined conditioning vector [B, C]. + attn_params: Attention parameters including sequence lengths. + + Returns: + Refined token embeddings [B, L, C]. + """ + gate_msa, gate_mlp = self.adaLN_modulation(c).chunk(2, dim=1) + norm_x = self.norm1(x) + qkv = self.self_attn_qkv(norm_x) + del norm_x + q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num) + del qkv + q = self.self_attn_q_norm(q).to(v) + k = self.self_attn_k_norm(k).to(v) + qkv = [q, k, v] + del q, k, v + attn = attention(qkv, attn_params=attn_params) + + x = x + apply_gate(self.self_attn_proj(attn), gate_msa) + x = x + apply_gate(self.mlp(self.norm2(x)), gate_mlp) + return x + + +class IndividualTokenRefiner(nn.Module): + """ + Stack of token refinement blocks with self-attention. + + Processes tokens individually with adaptive layer normalization. + + Args: + hidden_size: Model dimension. + heads_num: Number of attention heads. + depth: Number of refinement blocks. + mlp_width_ratio: MLP expansion ratio. + mlp_drop_rate: MLP dropout rate. + act_type: Activation function type. + qk_norm: QK normalization flag. + qk_norm_type: QK normalization type. + qkv_bias: Use bias in QKV projections. + """ + + def __init__( + self, + hidden_size: int, + heads_num: int, + depth: int, + mlp_width_ratio: float = 4.0, + mlp_drop_rate: float = 0.0, + act_type: str = "silu", + qk_norm: bool = False, + qk_norm_type: str = "layer", + qkv_bias: bool = True, + ): + super().__init__() + self.blocks = nn.ModuleList( + [ + IndividualTokenRefinerBlock( + hidden_size=hidden_size, + heads_num=heads_num, + mlp_width_ratio=mlp_width_ratio, + mlp_drop_rate=mlp_drop_rate, + act_type=act_type, + qk_norm=qk_norm, + qk_norm_type=qk_norm_type, + qkv_bias=qkv_bias, + ) + for _ in range(depth) + ] + ) + + def forward(self, x: torch.Tensor, c: torch.LongTensor, attn_params: AttentionParams) -> torch.Tensor: + """ + Apply sequential token refinement. + + Args: + x: Input token embeddings [B, L, C]. + c: Combined conditioning vector [B, C]. + attn_params: Attention parameters including sequence lengths. + + Returns: + Refined token embeddings [B, L, C]. + """ + for block in self.blocks: + x = block(x, c, attn_params) + return x + + +class SingleTokenRefiner(nn.Module): + """ + Text embedding refinement with timestep and context conditioning. + + Projects input text embeddings and applies self-attention refinement + conditioned on diffusion timestep and aggregate text context. + + Args: + in_channels: Input text embedding dimension. + hidden_size: Transformer hidden dimension. + heads_num: Number of attention heads. + depth: Number of refinement blocks. + """ + + def __init__(self, in_channels: int, hidden_size: int, heads_num: int, depth: int): + # Fixed architecture parameters for HunyuanImage-2.1 + mlp_drop_rate: float = 0.0 # No MLP dropout + act_type: str = "silu" # SiLU activation + mlp_width_ratio: float = 4.0 # 4x MLP expansion + qk_norm: bool = False # No QK normalization + qk_norm_type: str = "layer" # Layer norm type (unused) + qkv_bias: bool = True # Use QKV bias + + super().__init__() + self.input_embedder = nn.Linear(in_channels, hidden_size, bias=True) + act_layer = nn.SiLU + self.t_embedder = TimestepEmbedder(hidden_size, act_layer) + self.c_embedder = TextProjection(in_channels, hidden_size, act_layer) + self.individual_token_refiner = IndividualTokenRefiner( + hidden_size=hidden_size, + heads_num=heads_num, + depth=depth, + mlp_width_ratio=mlp_width_ratio, + mlp_drop_rate=mlp_drop_rate, + act_type=act_type, + qk_norm=qk_norm, + qk_norm_type=qk_norm_type, + qkv_bias=qkv_bias, + ) + + def forward(self, x: torch.Tensor, t: torch.LongTensor, attn_params: AttentionParams) -> torch.Tensor: + """ + Refine text embeddings with timestep conditioning. + + Args: + x: Input text embeddings [B, L, in_channels]. + t: Diffusion timestep [B]. + attn_params: Attention parameters including sequence lengths. + + Returns: + Refined embeddings [B, L, hidden_size]. + """ + timestep_aware_representations = self.t_embedder(t) + + # Compute context-aware representations by averaging valid tokens + txt_lens = attn_params.seqlens # img_len is not used for SingleTokenRefiner + context_aware_representations = torch.stack([x[i, : txt_lens[i]].mean(dim=0) for i in range(x.shape[0])], dim=0) # [B, C] + + context_aware_representations = self.c_embedder(context_aware_representations) + c = timestep_aware_representations + context_aware_representations + del timestep_aware_representations, context_aware_representations + x = self.input_embedder(x) + x = self.individual_token_refiner(x, c, attn_params) + return x + + +class FinalLayer(nn.Module): + """ + Final output projection layer with adaptive layer normalization. + + Projects transformer hidden states to output patch space with + timestep-conditioned modulation. + + Args: + hidden_size: Input hidden dimension. + patch_size: Spatial patch size for output reshaping. + out_channels: Number of output channels. + act_layer: Activation function class. + """ + + def __init__(self, hidden_size, patch_size, out_channels, act_layer): + super().__init__() + + # Layer normalization without learnable parameters + self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + out_size = (patch_size[0] * patch_size[1]) * out_channels + self.linear = nn.Linear(hidden_size, out_size, bias=True) + + # Adaptive layer normalization modulation + self.adaLN_modulation = nn.Sequential( + act_layer(), + nn.Linear(hidden_size, 2 * hidden_size, bias=True), + ) + + def forward(self, x, c): + shift, scale = self.adaLN_modulation(c).chunk(2, dim=1) + x = modulate(self.norm_final(x), shift=shift, scale=scale) + del shift, scale, c + x = self.linear(x) + return x + + +class RMSNorm(nn.Module): + """ + Root Mean Square Layer Normalization. + + Normalizes input using RMS and applies learnable scaling. + More efficient than LayerNorm as it doesn't compute mean. + + Args: + dim: Input feature dimension. + eps: Small value for numerical stability. + """ + + def __init__(self, dim: int, eps: float = 1e-6): + super().__init__() + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim)) + + def _norm(self, x): + """ + Apply RMS normalization. + + Args: + x: Input tensor. + + Returns: + RMS normalized tensor. + """ + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) + + def reset_parameters(self): + self.weight.fill_(1) + + def forward(self, x): + """ + Apply RMSNorm with learnable scaling. + + Args: + x: Input tensor. + + Returns: + Normalized and scaled tensor. + """ + output = self._norm(x.float()).type_as(x) + del x + # output = output * self.weight + # fp8 support + output = output * self.weight.to(output.dtype) + return output + + +# kept for reference, not used in current implementation +# class LinearWarpforSingle(nn.Module): +# """ +# Linear layer wrapper for concatenating and projecting two inputs. + +# Used in single-stream blocks to combine attention output with MLP features. + +# Args: +# in_dim: Input dimension (sum of both input feature dimensions). +# out_dim: Output dimension. +# bias: Whether to use bias in linear projection. +# """ + +# def __init__(self, in_dim: int, out_dim: int, bias=False): +# super().__init__() +# self.fc = nn.Linear(in_dim, out_dim, bias=bias) + +# def forward(self, x, y): +# """Concatenate inputs along feature dimension and project.""" +# x = torch.cat([x.contiguous(), y.contiguous()], dim=2).contiguous() +# return self.fc(x) + + +class ModulateDiT(nn.Module): + """ + Timestep conditioning modulation layer. + + Projects timestep embeddings to multiple modulation parameters + for adaptive layer normalization. + + Args: + hidden_size: Input conditioning dimension. + factor: Number of modulation parameters to generate. + act_layer: Activation function class. + """ + + def __init__(self, hidden_size: int, factor: int, act_layer: Callable): + super().__init__() + self.act = act_layer() + self.linear = nn.Linear(hidden_size, factor * hidden_size, bias=True) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.linear(self.act(x)) + + +class MMDoubleStreamBlock(nn.Module): + """ + Multimodal double-stream transformer block. + + Processes image and text tokens separately with cross-modal attention. + Each stream has its own normalization and MLP layers but shares + attention computation for cross-modal interaction. + + Args: + hidden_size: Model dimension. + heads_num: Number of attention heads. + mlp_width_ratio: MLP expansion ratio. + mlp_act_type: MLP activation function (only "gelu_tanh" supported). + qk_norm: QK normalization flag (must be True). + qk_norm_type: QK normalization type (only "rms" supported). + qkv_bias: Use bias in QKV projections. + """ + + def __init__( + self, + hidden_size: int, + heads_num: int, + mlp_width_ratio: float, + mlp_act_type: str = "gelu_tanh", + qk_norm: bool = True, + qk_norm_type: str = "rms", + qkv_bias: bool = False, + ): + super().__init__() + + assert mlp_act_type == "gelu_tanh", "Only GELU-tanh activation supported." + assert qk_norm_type == "rms", "Only RMS normalization supported." + assert qk_norm, "QK normalization must be enabled." + + self.heads_num = heads_num + head_dim = hidden_size // heads_num + mlp_hidden_dim = int(hidden_size * mlp_width_ratio) + + # Image stream processing components + self.img_mod = ModulateDiT(hidden_size, factor=6, act_layer=nn.SiLU) + self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + + self.img_attn_qkv = nn.Linear(hidden_size, hidden_size * 3, bias=qkv_bias) + + self.img_attn_q_norm = RMSNorm(head_dim, eps=1e-6) + self.img_attn_k_norm = RMSNorm(head_dim, eps=1e-6) + self.img_attn_proj = nn.Linear(hidden_size, hidden_size, bias=qkv_bias) + + self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.img_mlp = MLP(hidden_size, mlp_hidden_dim, act_layer=lambda: nn.GELU(approximate="tanh"), bias=True) + + # Text stream processing components + self.txt_mod = ModulateDiT(hidden_size, factor=6, act_layer=nn.SiLU) + self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + + self.txt_attn_qkv = nn.Linear(hidden_size, hidden_size * 3, bias=qkv_bias) + self.txt_attn_q_norm = RMSNorm(head_dim, eps=1e-6) + self.txt_attn_k_norm = RMSNorm(head_dim, eps=1e-6) + self.txt_attn_proj = nn.Linear(hidden_size, hidden_size, bias=qkv_bias) + + self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.txt_mlp = MLP(hidden_size, mlp_hidden_dim, act_layer=lambda: nn.GELU(approximate="tanh"), bias=True) + + self.gradient_checkpointing = False + self.cpu_offload_checkpointing = False + + def enable_gradient_checkpointing(self, cpu_offload: bool = False): + self.gradient_checkpointing = True + self.cpu_offload_checkpointing = cpu_offload + + def disable_gradient_checkpointing(self): + self.gradient_checkpointing = False + self.cpu_offload_checkpointing = False + + def _forward( + self, img: torch.Tensor, txt: torch.Tensor, vec: torch.Tensor, freqs_cis: tuple = None, attn_params: AttentionParams = None + ) -> Tuple[torch.Tensor, torch.Tensor]: + # Extract modulation parameters for image and text streams + (img_mod1_shift, img_mod1_scale, img_mod1_gate, img_mod2_shift, img_mod2_scale, img_mod2_gate) = self.img_mod(vec).chunk( + 6, dim=-1 + ) + (txt_mod1_shift, txt_mod1_scale, txt_mod1_gate, txt_mod2_shift, txt_mod2_scale, txt_mod2_gate) = self.txt_mod(vec).chunk( + 6, dim=-1 + ) + + # Process image stream for attention + img_modulated = self.img_norm1(img) + img_modulated = modulate(img_modulated, shift=img_mod1_shift, scale=img_mod1_scale) + del img_mod1_shift, img_mod1_scale + + img_qkv = self.img_attn_qkv(img_modulated) + del img_modulated + img_q, img_k, img_v = img_qkv.chunk(3, dim=-1) + del img_qkv + + img_q = rearrange(img_q, "B L (H D) -> B L H D", H=self.heads_num) + img_k = rearrange(img_k, "B L (H D) -> B L H D", H=self.heads_num) + img_v = rearrange(img_v, "B L (H D) -> B L H D", H=self.heads_num) + + # Apply QK-Norm if enabled + img_q = self.img_attn_q_norm(img_q).to(img_v) + img_k = self.img_attn_k_norm(img_k).to(img_v) + + # Apply rotary position embeddings to image tokens + if freqs_cis is not None: + img_q, img_k = apply_rotary_emb(img_q, img_k, freqs_cis, head_first=False) + del freqs_cis + + # Process text stream for attention + txt_modulated = self.txt_norm1(txt) + txt_modulated = modulate(txt_modulated, shift=txt_mod1_shift, scale=txt_mod1_scale) + + txt_qkv = self.txt_attn_qkv(txt_modulated) + del txt_modulated + txt_q, txt_k, txt_v = txt_qkv.chunk(3, dim=-1) + del txt_qkv + + txt_q = rearrange(txt_q, "B L (H D) -> B L H D", H=self.heads_num) + txt_k = rearrange(txt_k, "B L (H D) -> B L H D", H=self.heads_num) + txt_v = rearrange(txt_v, "B L (H D) -> B L H D", H=self.heads_num) + + # Apply QK-Norm if enabled + txt_q = self.txt_attn_q_norm(txt_q).to(txt_v) + txt_k = self.txt_attn_k_norm(txt_k).to(txt_v) + + # Concatenate image and text tokens for joint attention + img_seq_len = img.shape[1] + q = torch.cat([img_q, txt_q], dim=1) + del img_q, txt_q + k = torch.cat([img_k, txt_k], dim=1) + del img_k, txt_k + v = torch.cat([img_v, txt_v], dim=1) + del img_v, txt_v + + qkv = [q, k, v] + del q, k, v + attn = attention(qkv, attn_params=attn_params) + del qkv + + # Split attention outputs back to separate streams + img_attn, txt_attn = (attn[:, :img_seq_len].contiguous(), attn[:, img_seq_len:].contiguous()) + del attn + + # Apply attention projection and residual connection for image stream + img = img + apply_gate(self.img_attn_proj(img_attn), gate=img_mod1_gate) + del img_attn, img_mod1_gate + + # Apply MLP and residual connection for image stream + img = img + apply_gate( + self.img_mlp(modulate(self.img_norm2(img), shift=img_mod2_shift, scale=img_mod2_scale)), + gate=img_mod2_gate, + ) + del img_mod2_shift, img_mod2_scale, img_mod2_gate + + # Apply attention projection and residual connection for text stream + txt = txt + apply_gate(self.txt_attn_proj(txt_attn), gate=txt_mod1_gate) + del txt_attn, txt_mod1_gate + + # Apply MLP and residual connection for text stream + txt = txt + apply_gate( + self.txt_mlp(modulate(self.txt_norm2(txt), shift=txt_mod2_shift, scale=txt_mod2_scale)), + gate=txt_mod2_gate, + ) + del txt_mod2_shift, txt_mod2_scale, txt_mod2_gate + + return img, txt + + def forward( + self, img: torch.Tensor, txt: torch.Tensor, vec: torch.Tensor, freqs_cis: tuple = None, attn_params: AttentionParams = None + ) -> Tuple[torch.Tensor, torch.Tensor]: + if self.gradient_checkpointing and self.training: + forward_fn = self._forward + if self.cpu_offload_checkpointing: + forward_fn = custom_offloading_utils.cpu_offload_wrapper(forward_fn, self.img_attn_qkv.weight.device) + + return torch.utils.checkpoint.checkpoint(forward_fn, img, txt, vec, freqs_cis, attn_params, use_reentrant=False) + else: + return self._forward(img, txt, vec, freqs_cis, attn_params) + + +class MMSingleStreamBlock(nn.Module): + """ + Multimodal single-stream transformer block. + + Processes concatenated image and text tokens jointly with shared attention. + Uses parallel linear layers for efficiency and applies RoPE only to image tokens. + + Args: + hidden_size: Model dimension. + heads_num: Number of attention heads. + mlp_width_ratio: MLP expansion ratio. + mlp_act_type: MLP activation function (only "gelu_tanh" supported). + qk_norm: QK normalization flag (must be True). + qk_norm_type: QK normalization type (only "rms" supported). + qk_scale: Attention scaling factor (computed automatically if None). + """ + + def __init__( + self, + hidden_size: int, + heads_num: int, + mlp_width_ratio: float = 4.0, + mlp_act_type: str = "gelu_tanh", + qk_norm: bool = True, + qk_norm_type: str = "rms", + qk_scale: float = None, + ): + super().__init__() + + assert mlp_act_type == "gelu_tanh", "Only GELU-tanh activation supported." + assert qk_norm_type == "rms", "Only RMS normalization supported." + assert qk_norm, "QK normalization must be enabled." + + self.hidden_size = hidden_size + self.heads_num = heads_num + head_dim = hidden_size // heads_num + mlp_hidden_dim = int(hidden_size * mlp_width_ratio) + self.mlp_hidden_dim = mlp_hidden_dim + self.scale = qk_scale or head_dim**-0.5 + + # Parallel linear projections for efficiency + self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + mlp_hidden_dim) + + # Combined output projection + # self.linear2 = LinearWarpforSingle(hidden_size + mlp_hidden_dim, hidden_size, bias=True) # for reference + self.linear2 = nn.Linear(hidden_size + mlp_hidden_dim, hidden_size, bias=True) + + # QK normalization layers + self.q_norm = RMSNorm(head_dim, eps=1e-6) + self.k_norm = RMSNorm(head_dim, eps=1e-6) + + self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + + self.mlp_act = nn.GELU(approximate="tanh") + self.modulation = ModulateDiT(hidden_size, factor=3, act_layer=nn.SiLU) + + self.gradient_checkpointing = False + self.cpu_offload_checkpointing = False + + def enable_gradient_checkpointing(self, cpu_offload: bool = False): + self.gradient_checkpointing = True + self.cpu_offload_checkpointing = cpu_offload + + def disable_gradient_checkpointing(self): + self.gradient_checkpointing = False + self.cpu_offload_checkpointing = False + + def _forward( + self, + x: torch.Tensor, + vec: torch.Tensor, + freqs_cis: Tuple[torch.Tensor, torch.Tensor] = None, + attn_params: AttentionParams = None, + ) -> torch.Tensor: + # Extract modulation parameters + mod_shift, mod_scale, mod_gate = self.modulation(vec).chunk(3, dim=-1) + x_mod = modulate(self.pre_norm(x), shift=mod_shift, scale=mod_scale) + + # Compute Q, K, V, and MLP input + qkv_mlp = self.linear1(x_mod) + del x_mod + q, k, v, mlp = qkv_mlp.split([self.hidden_size, self.hidden_size, self.hidden_size, self.mlp_hidden_dim], dim=-1) + del qkv_mlp + + q = rearrange(q, "B L (H D) -> B L H D", H=self.heads_num) + k = rearrange(k, "B L (H D) -> B L H D", H=self.heads_num) + v = rearrange(v, "B L (H D) -> B L H D", H=self.heads_num) + + # Apply QK-Norm if enabled + q = self.q_norm(q).to(v) + k = self.k_norm(k).to(v) + + # Separate image and text tokens + img_q, txt_q = q[:, : attn_params.img_len, :, :], q[:, attn_params.img_len :, :, :] + del q + img_k, txt_k = k[:, : attn_params.img_len, :, :], k[:, attn_params.img_len :, :, :] + del k + + # Apply rotary position embeddings only to image tokens + img_q, img_k = apply_rotary_emb(img_q, img_k, freqs_cis, head_first=False) + del freqs_cis + + # Recombine and compute joint attention + q = torch.cat([img_q, txt_q], dim=1) + del img_q, txt_q + k = torch.cat([img_k, txt_k], dim=1) + del img_k, txt_k + # v = torch.cat([img_v, txt_v], dim=1) + # del img_v, txt_v + qkv = [q, k, v] + del q, k, v + attn = attention(qkv, attn_params=attn_params) + del qkv + + # Combine attention and MLP outputs, apply gating + # output = self.linear2(attn, self.mlp_act(mlp)) + + mlp = self.mlp_act(mlp) + output = torch.cat([attn, mlp], dim=2).contiguous() + del attn, mlp + output = self.linear2(output) + + return x + apply_gate(output, gate=mod_gate) + + def forward( + self, + x: torch.Tensor, + vec: torch.Tensor, + freqs_cis: Tuple[torch.Tensor, torch.Tensor] = None, + attn_params: AttentionParams = None, + ) -> torch.Tensor: + if self.gradient_checkpointing and self.training: + forward_fn = self._forward + if self.cpu_offload_checkpointing: + forward_fn = custom_offloading_utils.create_cpu_offloading_wrapper(forward_fn, self.linear1.weight.device) + + return torch.utils.checkpoint.checkpoint(forward_fn, x, vec, freqs_cis, attn_params, use_reentrant=False) + else: + return self._forward(x, vec, freqs_cis, attn_params) + + +# endregion diff --git a/library/hunyuan_image_text_encoder.py b/library/hunyuan_image_text_encoder.py new file mode 100644 index 000000000..2171b4101 --- /dev/null +++ b/library/hunyuan_image_text_encoder.py @@ -0,0 +1,661 @@ +import json +import re +from typing import Tuple, Optional, Union +import torch +from transformers import ( + AutoTokenizer, + Qwen2_5_VLConfig, + Qwen2_5_VLForConditionalGeneration, + Qwen2Tokenizer, + T5ForConditionalGeneration, + T5Config, + T5Tokenizer, +) +from transformers.models.t5.modeling_t5 import T5Stack +from accelerate import init_empty_weights + +from library.safetensors_utils import load_safetensors +from library.utils import setup_logging + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + + +BYT5_TOKENIZER_PATH = "google/byt5-small" +QWEN_2_5_VL_IMAGE_ID = "Qwen/Qwen2.5-VL-7B-Instruct" + + +# Copy from Glyph-SDXL-V2 + +COLOR_IDX_JSON = """{"white": 0, "black": 1, "darkslategray": 2, "dimgray": 3, "darkolivegreen": 4, "midnightblue": 5, "saddlebrown": 6, "sienna": 7, "whitesmoke": 8, "darkslateblue": 9, +"indianred": 10, "linen": 11, "maroon": 12, "khaki": 13, "sandybrown": 14, "gray": 15, "gainsboro": 16, "teal": 17, "peru": 18, "gold": 19, +"snow": 20, "firebrick": 21, "crimson": 22, "chocolate": 23, "tomato": 24, "brown": 25, "goldenrod": 26, "antiquewhite": 27, "rosybrown": 28, "steelblue": 29, +"floralwhite": 30, "seashell": 31, "darkgreen": 32, "oldlace": 33, "darkkhaki": 34, "burlywood": 35, "red": 36, "darkgray": 37, "orange": 38, "royalblue": 39, +"seagreen": 40, "lightgray": 41, "tan": 42, "coral": 43, "beige": 44, "palevioletred": 45, "wheat": 46, "lavender": 47, "darkcyan": 48, "slateblue": 49, +"slategray": 50, "orangered": 51, "silver": 52, "olivedrab": 53, "forestgreen": 54, "darkgoldenrod": 55, "ivory": 56, "darkorange": 57, "yellow": 58, "hotpink": 59, +"ghostwhite": 60, "lightcoral": 61, "indigo": 62, "bisque": 63, "darkred": 64, "darksalmon": 65, "lightslategray": 66, "dodgerblue": 67, "lightpink": 68, "mistyrose": 69, +"mediumvioletred": 70, "cadetblue": 71, "deeppink": 72, "salmon": 73, "palegoldenrod": 74, "blanchedalmond": 75, "lightseagreen": 76, "cornflowerblue": 77, "yellowgreen": 78, "greenyellow": 79, +"navajowhite": 80, "papayawhip": 81, "mediumslateblue": 82, "purple": 83, "blueviolet": 84, "pink": 85, "cornsilk": 86, "lightsalmon": 87, "mediumpurple": 88, "moccasin": 89, +"turquoise": 90, "mediumseagreen": 91, "lavenderblush": 92, "mediumblue": 93, "darkseagreen": 94, "mediumturquoise": 95, "paleturquoise": 96, "skyblue": 97, "lemonchiffon": 98, "olive": 99, +"peachpuff": 100, "lightyellow": 101, "lightsteelblue": 102, "mediumorchid": 103, "plum": 104, "darkturquoise": 105, "aliceblue": 106, "mediumaquamarine": 107, "orchid": 108, "powderblue": 109, +"blue": 110, "darkorchid": 111, "violet": 112, "lightskyblue": 113, "lightcyan": 114, "lightgoldenrodyellow": 115, "navy": 116, "thistle": 117, "honeydew": 118, "mintcream": 119, +"lightblue": 120, "darkblue": 121, "darkmagenta": 122, "deepskyblue": 123, "magenta": 124, "limegreen": 125, "darkviolet": 126, "cyan": 127, "palegreen": 128, "aquamarine": 129, +"lawngreen": 130, "lightgreen": 131, "azure": 132, "chartreuse": 133, "green": 134, "mediumspringgreen": 135, "lime": 136, "springgreen": 137}""" + +MULTILINGUAL_10_LANG_IDX_JSON = """{"en-Montserrat-Regular": 0, "en-Poppins-Italic": 1, "en-GlacialIndifference-Regular": 2, "en-OpenSans-ExtraBoldItalic": 3, "en-Montserrat-Bold": 4, "en-Now-Regular": 5, "en-Garet-Regular": 6, "en-LeagueSpartan-Bold": 7, "en-DMSans-Regular": 8, "en-OpenSauceOne-Regular": 9, +"en-OpenSans-ExtraBold": 10, "en-KGPrimaryPenmanship": 11, "en-Anton-Regular": 12, "en-Aileron-BlackItalic": 13, "en-Quicksand-Light": 14, "en-Roboto-BoldItalic": 15, "en-TheSeasons-It": 16, "en-Kollektif": 17, "en-Inter-BoldItalic": 18, "en-Poppins-Medium": 19, +"en-Poppins-Light": 20, "en-RoxboroughCF-RegularItalic": 21, "en-PlayfairDisplay-SemiBold": 22, "en-Agrandir-Italic": 23, "en-Lato-Regular": 24, "en-MoreSugarRegular": 25, "en-CanvaSans-RegularItalic": 26, "en-PublicSans-Italic": 27, "en-CodePro-NormalLC": 28, "en-Belleza-Regular": 29, +"en-JosefinSans-Bold": 30, "en-HKGrotesk-Bold": 31, "en-Telegraf-Medium": 32, "en-BrittanySignatureRegular": 33, "en-Raleway-ExtraBoldItalic": 34, "en-Mont-RegularItalic": 35, "en-Arimo-BoldItalic": 36, "en-Lora-Italic": 37, "en-ArchivoBlack-Regular": 38, "en-Poppins": 39, +"en-Barlow-Black": 40, "en-CormorantGaramond-Bold": 41, "en-LibreBaskerville-Regular": 42, "en-CanvaSchoolFontRegular": 43, "en-BebasNeueBold": 44, "en-LazydogRegular": 45, "en-FredokaOne-Regular": 46, "en-Horizon-Bold": 47, "en-Nourd-Regular": 48, "en-Hatton-Regular": 49, +"en-Nunito-ExtraBoldItalic": 50, "en-CerebriSans-Regular": 51, "en-Montserrat-Light": 52, "en-TenorSans": 53, "en-Norwester-Regular": 54, "en-ClearSans-Bold": 55, "en-Cardo-Regular": 56, "en-Alice-Regular": 57, "en-Oswald-Regular": 58, "en-Gaegu-Bold": 59, +"en-Muli-Black": 60, "en-TAN-PEARL-Regular": 61, "en-CooperHewitt-Book": 62, "en-Agrandir-Grand": 63, "en-BlackMango-Thin": 64, "en-DMSerifDisplay-Regular": 65, "en-Antonio-Bold": 66, "en-Sniglet-Regular": 67, "en-BeVietnam-Regular": 68, "en-NunitoSans10pt-BlackItalic": 69, +"en-AbhayaLibre-ExtraBold": 70, "en-Rubik-Regular": 71, "en-PPNeueMachina-Regular": 72, "en-TAN - MON CHERI-Regular": 73, "en-Jua-Regular": 74, "en-Playlist-Script": 75, "en-SourceSansPro-BoldItalic": 76, "en-MoonTime-Regular": 77, "en-Eczar-ExtraBold": 78, "en-Gatwick-Regular": 79, +"en-MonumentExtended-Regular": 80, "en-BarlowSemiCondensed-Regular": 81, "en-BarlowCondensed-Regular": 82, "en-Alegreya-Regular": 83, "en-DreamAvenue": 84, "en-RobotoCondensed-Italic": 85, "en-BobbyJones-Regular": 86, "en-Garet-ExtraBold": 87, "en-YesevaOne-Regular": 88, "en-Dosis-ExtraBold": 89, +"en-LeagueGothic-Regular": 90, "en-OpenSans-Italic": 91, "en-TANAEGEAN-Regular": 92, "en-Maharlika-Regular": 93, "en-MarykateRegular": 94, "en-Cinzel-Regular": 95, "en-Agrandir-Wide": 96, "en-Chewy-Regular": 97, "en-BodoniFLF-BoldItalic": 98, "en-Nunito-BlackItalic": 99, +"en-LilitaOne": 100, "en-HandyCasualCondensed-Regular": 101, "en-Ovo": 102, "en-Livvic-Regular": 103, "en-Agrandir-Narrow": 104, "en-CrimsonPro-Italic": 105, "en-AnonymousPro-Bold": 106, "en-NF-OneLittleFont-Bold": 107, "en-RedHatDisplay-BoldItalic": 108, "en-CodecPro-Regular": 109, +"en-HalimunRegular": 110, "en-LibreFranklin-Black": 111, "en-TeXGyreTermes-BoldItalic": 112, "en-Shrikhand-Regular": 113, "en-TTNormsPro-Italic": 114, "en-Gagalin-Regular": 115, "en-OpenSans-Bold": 116, "en-GreatVibes-Regular": 117, "en-Breathing": 118, "en-HeroLight-Regular": 119, +"en-KGPrimaryDots": 120, "en-Quicksand-Bold": 121, "en-Brice-ExtraLightSemiExpanded": 122, "en-Lato-BoldItalic": 123, "en-Fraunces9pt-Italic": 124, "en-AbrilFatface-Regular": 125, "en-BerkshireSwash-Regular": 126, "en-Atma-Bold": 127, "en-HolidayRegular": 128, "en-BebasNeueCyrillic": 129, +"en-IntroRust-Base": 130, "en-Gistesy": 131, "en-BDScript-Regular": 132, "en-ApricotsRegular": 133, "en-Prompt-Black": 134, "en-TAN MERINGUE": 135, "en-Sukar Regular": 136, "en-GentySans-Regular": 137, "en-NeueEinstellung-Normal": 138, "en-Garet-Bold": 139, +"en-FiraSans-Black": 140, "en-BantayogLight": 141, "en-NotoSerifDisplay-Black": 142, "en-TTChocolates-Regular": 143, "en-Ubuntu-Regular": 144, "en-Assistant-Bold": 145, "en-ABeeZee-Regular": 146, "en-LexendDeca-Regular": 147, "en-KingredSerif": 148, "en-Radley-Regular": 149, +"en-BrownSugar": 150, "en-MigraItalic-ExtraboldItalic": 151, "en-ChildosArabic-Regular": 152, "en-PeaceSans": 153, "en-LondrinaSolid-Black": 154, "en-SpaceMono-BoldItalic": 155, "en-RobotoMono-Light": 156, "en-CourierPrime-Regular": 157, "en-Alata-Regular": 158, "en-Amsterdam-One": 159, +"en-IreneFlorentina-Regular": 160, "en-CatchyMager": 161, "en-Alta_regular": 162, "en-ArticulatCF-Regular": 163, "en-Raleway-Regular": 164, "en-BrasikaDisplay": 165, "en-TANAngleton-Italic": 166, "en-NotoSerifDisplay-ExtraCondensedItalic": 167, "en-Bryndan Write": 168, "en-TTCommonsPro-It": 169, +"en-AlexBrush-Regular": 170, "en-Antic-Regular": 171, "en-TTHoves-Bold": 172, "en-DroidSerif": 173, "en-AblationRegular": 174, "en-Marcellus-Regular": 175, "en-Sanchez-Italic": 176, "en-JosefinSans": 177, "en-Afrah-Regular": 178, "en-PinyonScript": 179, +"en-TTInterphases-BoldItalic": 180, "en-Yellowtail-Regular": 181, "en-Gliker-Regular": 182, "en-BobbyJonesSoft-Regular": 183, "en-IBMPlexSans": 184, "en-Amsterdam-Three": 185, "en-Amsterdam-FourSlant": 186, "en-TTFors-Regular": 187, "en-Quattrocento": 188, "en-Sifonn-Basic": 189, +"en-AlegreyaSans-Black": 190, "en-Daydream": 191, "en-AristotelicaProTx-Rg": 192, "en-NotoSerif": 193, "en-EBGaramond-Italic": 194, "en-HammersmithOne-Regular": 195, "en-RobotoSlab-Regular": 196, "en-DO-Sans-Regular": 197, "en-KGPrimaryDotsLined": 198, "en-Blinker-Regular": 199, +"en-TAN NIMBUS": 200, "en-Blueberry-Regular": 201, "en-Rosario-Regular": 202, "en-Forum": 203, "en-MistrullyRegular": 204, "en-SourceSerifPro-Regular": 205, "en-Bugaki-Regular": 206, "en-CMUSerif-Roman": 207, "en-GulfsDisplay-NormalItalic": 208, "en-PTSans-Bold": 209, +"en-Sensei-Medium": 210, "en-SquadaOne-Regular": 211, "en-Arapey-Italic": 212, "en-Parisienne-Regular": 213, "en-Aleo-Italic": 214, "en-QuicheDisplay-Italic": 215, "en-RocaOne-It": 216, "en-Funtastic-Regular": 217, "en-PTSerif-BoldItalic": 218, "en-Muller-RegularItalic": 219, +"en-ArgentCF-Regular": 220, "en-Brightwall-Italic": 221, "en-Knewave-Regular": 222, "en-TYSerif-D": 223, "en-Agrandir-Tight": 224, "en-AlfaSlabOne-Regular": 225, "en-TANTangkiwood-Display": 226, "en-Kief-Montaser-Regular": 227, "en-Gotham-Book": 228, "en-JuliusSansOne-Regular": 229, +"en-CocoGothic-Italic": 230, "en-SairaCondensed-Regular": 231, "en-DellaRespira-Regular": 232, "en-Questrial-Regular": 233, "en-BukhariScript-Regular": 234, "en-HelveticaWorld-Bold": 235, "en-TANKINDRED-Display": 236, "en-CinzelDecorative-Regular": 237, "en-Vidaloka-Regular": 238, "en-AlegreyaSansSC-Black": 239, +"en-FeelingPassionate-Regular": 240, "en-QuincyCF-Regular": 241, "en-FiraCode-Regular": 242, "en-Genty-Regular": 243, "en-Nickainley-Normal": 244, "en-RubikOne-Regular": 245, "en-Gidole-Regular": 246, "en-Borsok": 247, "en-Gordita-RegularItalic": 248, "en-Scripter-Regular": 249, +"en-Buffalo-Regular": 250, "en-KleinText-Regular": 251, "en-Creepster-Regular": 252, "en-Arvo-Bold": 253, "en-GabrielSans-NormalItalic": 254, "en-Heebo-Black": 255, "en-LexendExa-Regular": 256, "en-BrixtonSansTC-Regular": 257, "en-GildaDisplay-Regular": 258, "en-ChunkFive-Roman": 259, +"en-Amaranth-BoldItalic": 260, "en-BubbleboddyNeue-Regular": 261, "en-MavenPro-Bold": 262, "en-TTDrugs-Italic": 263, "en-CyGrotesk-KeyRegular": 264, "en-VarelaRound-Regular": 265, "en-Ruda-Black": 266, "en-SafiraMarch": 267, "en-BloggerSans": 268, "en-TANHEADLINE-Regular": 269, +"en-SloopScriptPro-Regular": 270, "en-NeueMontreal-Regular": 271, "en-Schoolbell-Regular": 272, "en-SigherRegular": 273, "en-InriaSerif-Regular": 274, "en-JetBrainsMono-Regular": 275, "en-MADEEvolveSans": 276, "en-Dekko": 277, "en-Handyman-Regular": 278, "en-Aileron-BoldItalic": 279, +"en-Bright-Italic": 280, "en-Solway-Regular": 281, "en-Higuen-Regular": 282, "en-WedgesItalic": 283, "en-TANASHFORD-BOLD": 284, "en-IBMPlexMono": 285, "en-RacingSansOne-Regular": 286, "en-RegularBrush": 287, "en-OpenSans-LightItalic": 288, "en-SpecialElite-Regular": 289, +"en-FuturaLTPro-Medium": 290, "en-MaragsaDisplay": 291, "en-BigShouldersDisplay-Regular": 292, "en-BDSans-Regular": 293, "en-RasputinRegular": 294, "en-Yvesyvesdrawing-BoldItalic": 295, "en-Bitter-Regular": 296, "en-LuckiestGuy-Regular": 297, "en-CanvaSchoolFontDotted": 298, "en-TTFirsNeue-Italic": 299, +"en-Sunday-Regular": 300, "en-HKGothic-MediumItalic": 301, "en-CaveatBrush-Regular": 302, "en-HeliosExt": 303, "en-ArchitectsDaughter-Regular": 304, "en-Angelina": 305, "en-Calistoga-Regular": 306, "en-ArchivoNarrow-Regular": 307, "en-ObjectSans-MediumSlanted": 308, "en-AyrLucidityCondensed-Regular": 309, +"en-Nexa-RegularItalic": 310, "en-Lustria-Regular": 311, "en-Amsterdam-TwoSlant": 312, "en-Virtual-Regular": 313, "en-Brusher-Regular": 314, "en-NF-Lepetitcochon-Regular": 315, "en-TANTWINKLE": 316, "en-LeJour-Serif": 317, "en-Prata-Regular": 318, "en-PPWoodland-Regular": 319, +"en-PlayfairDisplay-BoldItalic": 320, "en-AmaticSC-Regular": 321, "en-Cabin-Regular": 322, "en-Manjari-Bold": 323, "en-MrDafoe-Regular": 324, "en-TTRamillas-Italic": 325, "en-Luckybones-Bold": 326, "en-DarkerGrotesque-Light": 327, "en-BellabooRegular": 328, "en-CormorantSC-Bold": 329, +"en-GochiHand-Regular": 330, "en-Atteron": 331, "en-RocaTwo-Lt": 332, "en-ZCOOLXiaoWei-Regular": 333, "en-TANSONGBIRD": 334, "en-HeadingNow-74Regular": 335, "en-Luthier-BoldItalic": 336, "en-Oregano-Regular": 337, "en-AyrTropikaIsland-Int": 338, "en-Mali-Regular": 339, +"en-DidactGothic-Regular": 340, "en-Lovelace-Regular": 341, "en-BakerieSmooth-Regular": 342, "en-CarterOne": 343, "en-HussarBd": 344, "en-OldStandard-Italic": 345, "en-TAN-ASTORIA-Display": 346, "en-rugratssans-Regular": 347, "en-BMHANNA": 348, "en-BetterSaturday": 349, +"en-AdigianaToybox": 350, "en-Sailors": 351, "en-PlayfairDisplaySC-Italic": 352, "en-Etna-Regular": 353, "en-Revive80Signature": 354, "en-CAGenerated": 355, "en-Poppins-Regular": 356, "en-Jonathan-Regular": 357, "en-Pacifico-Regular": 358, "en-Saira-Black": 359, +"en-Loubag-Regular": 360, "en-Decalotype-Black": 361, "en-Mansalva-Regular": 362, "en-Allura-Regular": 363, "en-ProximaNova-Bold": 364, "en-TANMIGNON-DISPLAY": 365, "en-ArsenicaAntiqua-Regular": 366, "en-BreulGroteskA-RegularItalic": 367, "en-HKModular-Bold": 368, "en-TANNightingale-Regular": 369, +"en-AristotelicaProCndTxt-Rg": 370, "en-Aprila-Regular": 371, "en-Tomorrow-Regular": 372, "en-AngellaWhite": 373, "en-KaushanScript-Regular": 374, "en-NotoSans": 375, "en-LeJour-Script": 376, "en-BrixtonTC-Regular": 377, "en-OleoScript-Regular": 378, "en-Cakerolli-Regular": 379, +"en-Lobster-Regular": 380, "en-FrunchySerif-Regular": 381, "en-PorcelainRegular": 382, "en-AlojaExtended": 383, "en-SergioTrendy-Italic": 384, "en-LovelaceText-Bold": 385, "en-Anaktoria": 386, "en-JimmyScript-Light": 387, "en-IBMPlexSerif": 388, "en-Marta": 389, +"en-Mango-Regular": 390, "en-Overpass-Italic": 391, "en-Hagrid-Regular": 392, "en-ElikaGorica": 393, "en-Amiko-Regular": 394, "en-EFCOBrookshire-Regular": 395, "en-Caladea-Regular": 396, "en-MoonlightBold": 397, "en-Staatliches-Regular": 398, "en-Helios-Bold": 399, +"en-Satisfy-Regular": 400, "en-NexaScript-Regular": 401, "en-Trocchi-Regular": 402, "en-March": 403, "en-IbarraRealNova-Regular": 404, "en-Nectarine-Regular": 405, "en-Overpass-Light": 406, "en-TruetypewriterPolyglOTT": 407, "en-Bangers-Regular": 408, "en-Lazord-BoldExpandedItalic": 409, +"en-Chloe-Regular": 410, "en-BaskervilleDisplayPT-Regular": 411, "en-Bright-Regular": 412, "en-Vollkorn-Regular": 413, "en-Harmattan": 414, "en-SortsMillGoudy-Regular": 415, "en-Biryani-Bold": 416, "en-SugoProDisplay-Italic": 417, "en-Lazord-BoldItalic": 418, "en-Alike-Regular": 419, +"en-PermanentMarker-Regular": 420, "en-Sacramento-Regular": 421, "en-HKGroteskPro-Italic": 422, "en-Aleo-BoldItalic": 423, "en-Noot": 424, "en-TANGARLAND-Regular": 425, "en-Twister": 426, "en-Arsenal-Italic": 427, "en-Bogart-Italic": 428, "en-BethEllen-Regular": 429, +"en-Caveat-Regular": 430, "en-BalsamiqSans-Bold": 431, "en-BreeSerif-Regular": 432, "en-CodecPro-ExtraBold": 433, "en-Pierson-Light": 434, "en-CyGrotesk-WideRegular": 435, "en-Lumios-Marker": 436, "en-Comfortaa-Bold": 437, "en-TraceFontRegular": 438, "en-RTL-AdamScript-Regular": 439, +"en-EastmanGrotesque-Italic": 440, "en-Kalam-Bold": 441, "en-ChauPhilomeneOne-Regular": 442, "en-Coiny-Regular": 443, "en-Lovera": 444, "en-Gellatio": 445, "en-TitilliumWeb-Bold": 446, "en-OilvareBase-Italic": 447, "en-Catamaran-Black": 448, "en-Anteb-Italic": 449, +"en-SueEllenFrancisco": 450, "en-SweetApricot": 451, "en-BrightSunshine": 452, "en-IM_FELL_Double_Pica_Italic": 453, "en-Granaina-limpia": 454, "en-TANPARFAIT": 455, "en-AcherusGrotesque-Regular": 456, "en-AwesomeLathusca-Italic": 457, "en-Signika-Bold": 458, "en-Andasia": 459, +"en-DO-AllCaps-Slanted": 460, "en-Zenaida-Regular": 461, "en-Fahkwang-Regular": 462, "en-Play-Regular": 463, "en-BERNIERRegular-Regular": 464, "en-PlumaThin-Regular": 465, "en-SportsWorld": 466, "en-Garet-Black": 467, "en-CarolloPlayscript-BlackItalic": 468, "en-Cheque-Regular": 469, +"en-SEGO": 470, "en-BobbyJones-Condensed": 471, "en-NexaSlab-RegularItalic": 472, "en-DancingScript-Regular": 473, "en-PaalalabasDisplayWideBETA": 474, "en-Magnolia-Script": 475, "en-OpunMai-400It": 476, "en-MadelynFill-Regular": 477, "en-ZingRust-Base": 478, "en-FingerPaint-Regular": 479, +"en-BostonAngel-Light": 480, "en-Gliker-RegularExpanded": 481, "en-Ahsing": 482, "en-Engagement-Regular": 483, "en-EyesomeScript": 484, "en-LibraSerifModern-Regular": 485, "en-London-Regular": 486, "en-AtkinsonHyperlegible-Regular": 487, "en-StadioNow-TextItalic": 488, "en-Aniyah": 489, +"en-ITCAvantGardePro-Bold": 490, "en-Comica-Regular": 491, "en-Coustard-Regular": 492, "en-Brice-BoldCondensed": 493, "en-TANNEWYORK-Bold": 494, "en-TANBUSTER-Bold": 495, "en-Alatsi-Regular": 496, "en-TYSerif-Book": 497, "en-Jingleberry": 498, "en-Rajdhani-Bold": 499, +"en-LobsterTwo-BoldItalic": 500, "en-BestLight-Medium": 501, "en-Hitchcut-Regular": 502, "en-GermaniaOne-Regular": 503, "en-Emitha-Script": 504, "en-LemonTuesday": 505, "en-Cubao_Free_Regular": 506, "en-MonterchiSerif-Regular": 507, "en-AllertaStencil-Regular": 508, "en-RTL-Sondos-Regular": 509, +"en-HomemadeApple-Regular": 510, "en-CosmicOcto-Medium": 511, "cn-HelloFont-FangHuaTi": 0, "cn-HelloFont-ID-DianFangSong-Bold": 1, "cn-HelloFont-ID-DianFangSong": 2, "cn-HelloFont-ID-DianHei-CEJ": 3, "cn-HelloFont-ID-DianHei-DEJ": 4, "cn-HelloFont-ID-DianHei-EEJ": 5, "cn-HelloFont-ID-DianHei-FEJ": 6, "cn-HelloFont-ID-DianHei-GEJ": 7, "cn-HelloFont-ID-DianKai-Bold": 8, "cn-HelloFont-ID-DianKai": 9, +"cn-HelloFont-WenYiHei": 10, "cn-Hellofont-ID-ChenYanXingKai": 11, "cn-Hellofont-ID-DaZiBao": 12, "cn-Hellofont-ID-DaoCaoRen": 13, "cn-Hellofont-ID-JianSong": 14, "cn-Hellofont-ID-JiangHuZhaoPaiHei": 15, "cn-Hellofont-ID-KeSong": 16, "cn-Hellofont-ID-LeYuanTi": 17, "cn-Hellofont-ID-Pinocchio": 18, "cn-Hellofont-ID-QiMiaoTi": 19, +"cn-Hellofont-ID-QingHuaKai": 20, "cn-Hellofont-ID-QingHuaXingKai": 21, "cn-Hellofont-ID-ShanShuiXingKai": 22, "cn-Hellofont-ID-ShouXieQiShu": 23, "cn-Hellofont-ID-ShouXieTongZhenTi": 24, "cn-Hellofont-ID-TengLingTi": 25, "cn-Hellofont-ID-XiaoLiShu": 26, "cn-Hellofont-ID-XuanZhenSong": 27, "cn-Hellofont-ID-ZhongLingXingKai": 28, "cn-HellofontIDJiaoTangTi": 29, +"cn-HellofontIDJiuZhuTi": 30, "cn-HuXiaoBao-SaoBao": 31, "cn-HuXiaoBo-NanShen": 32, "cn-HuXiaoBo-ZhenShuai": 33, "cn-SourceHanSansSC-Bold": 34, "cn-SourceHanSansSC-ExtraLight": 35, "cn-SourceHanSansSC-Heavy": 36, "cn-SourceHanSansSC-Light": 37, "cn-SourceHanSansSC-Medium": 38, "cn-SourceHanSansSC-Normal": 39, +"cn-SourceHanSansSC-Regular": 40, "cn-SourceHanSerifSC-Bold": 41, "cn-SourceHanSerifSC-ExtraLight": 42, "cn-SourceHanSerifSC-Heavy": 43, "cn-SourceHanSerifSC-Light": 44, "cn-SourceHanSerifSC-Medium": 45, "cn-SourceHanSerifSC-Regular": 46, "cn-SourceHanSerifSC-SemiBold": 47, "cn-xiaowei": 48, "cn-AaJianHaoTi": 49, +"cn-AlibabaPuHuiTi-Bold": 50, "cn-AlibabaPuHuiTi-Heavy": 51, "cn-AlibabaPuHuiTi-Light": 52, "cn-AlibabaPuHuiTi-Medium": 53, "cn-AlibabaPuHuiTi-Regular": 54, "cn-CanvaAcidBoldSC": 55, "cn-CanvaBreezeCN": 56, "cn-CanvaBumperCropSC": 57, "cn-CanvaCakeShopCN": 58, "cn-CanvaEndeavorBlackSC": 59, +"cn-CanvaJoyHeiCN": 60, "cn-CanvaLiCN": 61, "cn-CanvaOrientalBrushCN": 62, "cn-CanvaPoster": 63, "cn-CanvaQinfuCalligraphyCN": 64, "cn-CanvaSweetHeartCN": 65, "cn-CanvaSwordLikeDreamCN": 66, "cn-CanvaTangyuanHandwritingCN": 67, "cn-CanvaWanderWorldCN": 68, "cn-CanvaWenCN": 69, +"cn-DianZiChunYi": 70, "cn-GenSekiGothicTW-H": 71, "cn-GenWanMinTW-L": 72, "cn-GenYoMinTW-B": 73, "cn-GenYoMinTW-EL": 74, "cn-GenYoMinTW-H": 75, "cn-GenYoMinTW-M": 76, "cn-GenYoMinTW-R": 77, "cn-GenYoMinTW-SB": 78, "cn-HYQiHei-AZEJ": 79, +"cn-HYQiHei-EES": 80, "cn-HanaMinA": 81, "cn-HappyZcool-2016": 82, "cn-HelloFont ZJ KeKouKeAiTi": 83, "cn-HelloFont-ID-BoBoTi": 84, "cn-HelloFont-ID-FuGuHei-25": 85, "cn-HelloFont-ID-FuGuHei-35": 86, "cn-HelloFont-ID-FuGuHei-45": 87, "cn-HelloFont-ID-FuGuHei-55": 88, "cn-HelloFont-ID-FuGuHei-65": 89, +"cn-HelloFont-ID-FuGuHei-75": 90, "cn-HelloFont-ID-FuGuHei-85": 91, "cn-HelloFont-ID-HeiKa": 92, "cn-HelloFont-ID-HeiTang": 93, "cn-HelloFont-ID-JianSong-95": 94, "cn-HelloFont-ID-JueJiangHei-50": 95, "cn-HelloFont-ID-JueJiangHei-55": 96, "cn-HelloFont-ID-JueJiangHei-60": 97, "cn-HelloFont-ID-JueJiangHei-65": 98, "cn-HelloFont-ID-JueJiangHei-70": 99, +"cn-HelloFont-ID-JueJiangHei-75": 100, "cn-HelloFont-ID-JueJiangHei-80": 101, "cn-HelloFont-ID-KuHeiTi": 102, "cn-HelloFont-ID-LingDongTi": 103, "cn-HelloFont-ID-LingLiTi": 104, "cn-HelloFont-ID-MuFengTi": 105, "cn-HelloFont-ID-NaiNaiJiangTi": 106, "cn-HelloFont-ID-PangDu": 107, "cn-HelloFont-ID-ReLieTi": 108, "cn-HelloFont-ID-RouRun": 109, +"cn-HelloFont-ID-SaShuangShouXieTi": 110, "cn-HelloFont-ID-WangZheFengFan": 111, "cn-HelloFont-ID-YouQiTi": 112, "cn-Hellofont-ID-XiaLeTi": 113, "cn-Hellofont-ID-XianXiaTi": 114, "cn-HuXiaoBoKuHei": 115, "cn-IDDanMoXingKai": 116, "cn-IDJueJiangHei": 117, "cn-IDMeiLingTi": 118, "cn-IDQQSugar": 119, +"cn-LiuJianMaoCao-Regular": 120, "cn-LongCang-Regular": 121, "cn-MaShanZheng-Regular": 122, "cn-PangMenZhengDao-3": 123, "cn-PangMenZhengDao-Cu": 124, "cn-PangMenZhengDao": 125, "cn-SentyCaramel": 126, "cn-SourceHanSerifSC": 127, "cn-WenCang-Regular": 128, "cn-WenQuanYiMicroHei": 129, +"cn-XianErTi": 130, "cn-YRDZSTJF": 131, "cn-YS-HelloFont-BangBangTi": 132, "cn-ZCOOLKuaiLe-Regular": 133, "cn-ZCOOLQingKeHuangYou-Regular": 134, "cn-ZCOOLXiaoWei-Regular": 135, "cn-ZCOOL_KuHei": 136, "cn-ZhiMangXing-Regular": 137, "cn-baotuxiaobaiti": 138, "cn-jiangxizhuokai-Regular": 139, +"cn-zcool-gdh": 140, "cn-zcoolqingkehuangyouti-Regular": 141, "cn-zcoolwenyiti": 142, "jp-04KanjyukuGothic": 0, "jp-07LightNovelPOP": 1, "jp-07NikumaruFont": 2, "jp-07YasashisaAntique": 3, "jp-07YasashisaGothic": 4, "jp-BokutachinoGothic2Bold": 5, "jp-BokutachinoGothic2Regular": 6, "jp-CHI_SpeedyRight_full_211128-Regular": 7, "jp-CHI_SpeedyRight_italic_full_211127-Regular": 8, "jp-CP-Font": 9, +"jp-Canva_CezanneProN-B": 10, "jp-Canva_CezanneProN-M": 11, "jp-Canva_ChiaroStd-B": 12, "jp-Canva_CometStd-B": 13, "jp-Canva_DotMincho16Std-M": 14, "jp-Canva_GrecoStd-B": 15, "jp-Canva_GrecoStd-M": 16, "jp-Canva_LyraStd-DB": 17, "jp-Canva_MatisseHatsuhiPro-B": 18, "jp-Canva_MatisseHatsuhiPro-M": 19, +"jp-Canva_ModeMinAStd-B": 20, "jp-Canva_NewCezanneProN-B": 21, "jp-Canva_NewCezanneProN-M": 22, "jp-Canva_PearlStd-L": 23, "jp-Canva_RaglanStd-UB": 24, "jp-Canva_RailwayStd-B": 25, "jp-Canva_ReggaeStd-B": 26, "jp-Canva_RocknRollStd-DB": 27, "jp-Canva_RodinCattleyaPro-B": 28, "jp-Canva_RodinCattleyaPro-M": 29, +"jp-Canva_RodinCattleyaPro-UB": 30, "jp-Canva_RodinHimawariPro-B": 31, "jp-Canva_RodinHimawariPro-M": 32, "jp-Canva_RodinMariaPro-B": 33, "jp-Canva_RodinMariaPro-DB": 34, "jp-Canva_RodinProN-M": 35, "jp-Canva_ShadowTLStd-B": 36, "jp-Canva_StickStd-B": 37, "jp-Canva_TsukuAOldMinPr6N-B": 38, "jp-Canva_TsukuAOldMinPr6N-R": 39, +"jp-Canva_UtrilloPro-DB": 40, "jp-Canva_UtrilloPro-M": 41, "jp-Canva_YurukaStd-UB": 42, "jp-FGUIGEN": 43, "jp-GlowSansJ-Condensed-Heavy": 44, "jp-GlowSansJ-Condensed-Light": 45, "jp-GlowSansJ-Normal-Bold": 46, "jp-GlowSansJ-Normal-Light": 47, "jp-HannariMincho": 48, "jp-HarenosoraMincho": 49, +"jp-Jiyucho": 50, "jp-Kaiso-Makina-B": 51, "jp-Kaisotai-Next-UP-B": 52, "jp-KokoroMinchoutai": 53, "jp-Mamelon-3-Hi-Regular": 54, "jp-MotoyaAnemoneStd-W1": 55, "jp-MotoyaAnemoneStd-W5": 56, "jp-MotoyaAnticPro-W3": 57, "jp-MotoyaCedarStd-W3": 58, "jp-MotoyaCedarStd-W5": 59, +"jp-MotoyaGochikaStd-W4": 60, "jp-MotoyaGochikaStd-W8": 61, "jp-MotoyaGothicMiyabiStd-W6": 62, "jp-MotoyaGothicStd-W3": 63, "jp-MotoyaGothicStd-W5": 64, "jp-MotoyaKoinStd-W3": 65, "jp-MotoyaKyotaiStd-W2": 66, "jp-MotoyaKyotaiStd-W4": 67, "jp-MotoyaMaruStd-W3": 68, "jp-MotoyaMaruStd-W5": 69, +"jp-MotoyaMinchoMiyabiStd-W4": 70, "jp-MotoyaMinchoMiyabiStd-W6": 71, "jp-MotoyaMinchoModernStd-W4": 72, "jp-MotoyaMinchoModernStd-W6": 73, "jp-MotoyaMinchoStd-W3": 74, "jp-MotoyaMinchoStd-W5": 75, "jp-MotoyaReisyoStd-W2": 76, "jp-MotoyaReisyoStd-W6": 77, "jp-MotoyaTohitsuStd-W4": 78, "jp-MotoyaTohitsuStd-W6": 79, +"jp-MtySousyokuEmBcJis-W6": 80, "jp-MtySousyokuLiBcJis-W6": 81, "jp-Mushin": 82, "jp-NotoSansJP-Bold": 83, "jp-NotoSansJP-Regular": 84, "jp-NudMotoyaAporoStd-W3": 85, "jp-NudMotoyaAporoStd-W5": 86, "jp-NudMotoyaCedarStd-W3": 87, "jp-NudMotoyaCedarStd-W5": 88, "jp-NudMotoyaMaruStd-W3": 89, +"jp-NudMotoyaMaruStd-W5": 90, "jp-NudMotoyaMinchoStd-W5": 91, "jp-Ounen-mouhitsu": 92, "jp-Ronde-B-Square": 93, "jp-SMotoyaGyosyoStd-W5": 94, "jp-SMotoyaSinkaiStd-W3": 95, "jp-SMotoyaSinkaiStd-W5": 96, "jp-SourceHanSansJP-Bold": 97, "jp-SourceHanSansJP-Regular": 98, "jp-SourceHanSerifJP-Bold": 99, +"jp-SourceHanSerifJP-Regular": 100, "jp-TazuganeGothicStdN-Bold": 101, "jp-TazuganeGothicStdN-Regular": 102, "jp-TelopMinProN-B": 103, "jp-Togalite-Bold": 104, "jp-Togalite-Regular": 105, "jp-TsukuMinPr6N-E": 106, "jp-TsukuMinPr6N-M": 107, "jp-mikachan_o": 108, "jp-nagayama_kai": 109, +"jp-07LogoTypeGothic7": 110, "jp-07TetsubinGothic": 111, "jp-851CHIKARA-DZUYOKU-KANA-A": 112, "jp-ARMinchoJIS-Light": 113, "jp-ARMinchoJIS-Ultra": 114, "jp-ARPCrystalMinchoJIS-Medium": 115, "jp-ARPCrystalRGothicJIS-Medium": 116, "jp-ARShounanShinpitsuGyosyoJIS-Medium": 117, "jp-AozoraMincho-bold": 118, "jp-AozoraMinchoRegular": 119, +"jp-ArialUnicodeMS-Bold": 120, "jp-ArialUnicodeMS": 121, "jp-CanvaBreezeJP": 122, "jp-CanvaLiCN": 123, "jp-CanvaLiJP": 124, "jp-CanvaOrientalBrushCN": 125, "jp-CanvaQinfuCalligraphyJP": 126, "jp-CanvaSweetHeartJP": 127, "jp-CanvaWenJP": 128, "jp-Corporate-Logo-Bold": 129, +"jp-DelaGothicOne-Regular": 130, "jp-GN-Kin-iro_SansSerif": 131, "jp-GN-Koharuiro_Sunray": 132, "jp-GenEiGothicM-B": 133, "jp-GenEiGothicM-R": 134, "jp-GenJyuuGothic-Bold": 135, "jp-GenRyuMinTW-B": 136, "jp-GenRyuMinTW-R": 137, "jp-GenSekiGothicTW-B": 138, "jp-GenSekiGothicTW-R": 139, +"jp-GenSenRoundedTW-B": 140, "jp-GenSenRoundedTW-R": 141, "jp-GenShinGothic-Bold": 142, "jp-GenShinGothic-Normal": 143, "jp-GenWanMinTW-L": 144, "jp-GenYoGothicTW-B": 145, "jp-GenYoGothicTW-R": 146, "jp-GenYoMinTW-B": 147, "jp-GenYoMinTW-R": 148, "jp-HGBouquet": 149, +"jp-HanaMinA": 150, "jp-HanazomeFont": 151, "jp-HinaMincho-Regular": 152, "jp-Honoka-Antique-Maru": 153, "jp-Honoka-Mincho": 154, "jp-HuiFontP": 155, "jp-IPAexMincho": 156, "jp-JK-Gothic-L": 157, "jp-JK-Gothic-M": 158, "jp-JackeyFont": 159, +"jp-KaiseiTokumin-Bold": 160, "jp-KaiseiTokumin-Regular": 161, "jp-Keifont": 162, "jp-KiwiMaru-Regular": 163, "jp-Koku-Mincho-Regular": 164, "jp-MotoyaLMaru-W3-90ms-RKSJ-H": 165, "jp-NewTegomin-Regular": 166, "jp-NicoKaku": 167, "jp-NicoMoji+": 168, "jp-Otsutome_font-Bold": 169, +"jp-PottaOne-Regular": 170, "jp-RampartOne-Regular": 171, "jp-Senobi-Gothic-Bold": 172, "jp-Senobi-Gothic-Regular": 173, "jp-SmartFontUI-Proportional": 174, "jp-SoukouMincho": 175, "jp-TEST_Klee-DB": 176, "jp-TEST_Klee-M": 177, "jp-TEST_UDMincho-B": 178, "jp-TEST_UDMincho-L": 179, +"jp-TT_Akakane-EB": 180, "jp-Tanuki-Permanent-Marker": 181, "jp-TrainOne-Regular": 182, "jp-TsunagiGothic-Black": 183, "jp-Ume-Hy-Gothic": 184, "jp-Ume-P-Mincho": 185, "jp-WenQuanYiMicroHei": 186, "jp-XANO-mincho-U32": 187, "jp-YOzFontM90-Regular": 188, "jp-Yomogi-Regular": 189, +"jp-YujiBoku-Regular": 190, "jp-YujiSyuku-Regular": 191, "jp-ZenKakuGothicNew-Bold": 192, "jp-ZenKakuGothicNew-Regular": 193, "jp-ZenKurenaido-Regular": 194, "jp-ZenMaruGothic-Bold": 195, "jp-ZenMaruGothic-Regular": 196, "jp-darts-font": 197, "jp-irohakakuC-Bold": 198, "jp-irohakakuC-Medium": 199, +"jp-irohakakuC-Regular": 200, "jp-katyou": 201, "jp-mplus-1m-bold": 202, "jp-mplus-1m-regular": 203, "jp-mplus-1p-bold": 204, "jp-mplus-1p-regular": 205, "jp-rounded-mplus-1p-bold": 206, "jp-rounded-mplus-1p-regular": 207, "jp-timemachine-wa": 208, "jp-ttf-GenEiLateMin-Medium": 209, +"jp-uzura_font": 210, "kr-Arita-buri-Bold_OTF": 0, "kr-Arita-buri-HairLine_OTF": 1, "kr-Arita-buri-Light_OTF": 2, "kr-Arita-buri-Medium_OTF": 3, "kr-Arita-buri-SemiBold_OTF": 4, "kr-Canva_YDSunshineL": 5, "kr-Canva_YDSunshineM": 6, "kr-Canva_YoonGulimPro710": 7, "kr-Canva_YoonGulimPro730": 8, "kr-Canva_YoonGulimPro740": 9, +"kr-Canva_YoonGulimPro760": 10, "kr-Canva_YoonGulimPro770": 11, "kr-Canva_YoonGulimPro790": 12, "kr-CreHappB": 13, "kr-CreHappL": 14, "kr-CreHappM": 15, "kr-CreHappS": 16, "kr-OTAuroraB": 17, "kr-OTAuroraL": 18, "kr-OTAuroraR": 19, +"kr-OTDoldamgilB": 20, "kr-OTDoldamgilL": 21, "kr-OTDoldamgilR": 22, "kr-OTHamsterB": 23, "kr-OTHamsterL": 24, "kr-OTHamsterR": 25, "kr-OTHapchangdanB": 26, "kr-OTHapchangdanL": 27, "kr-OTHapchangdanR": 28, "kr-OTSupersizeBkBOX": 29, +"kr-SourceHanSansKR-Bold": 30, "kr-SourceHanSansKR-ExtraLight": 31, "kr-SourceHanSansKR-Heavy": 32, "kr-SourceHanSansKR-Light": 33, "kr-SourceHanSansKR-Medium": 34, "kr-SourceHanSansKR-Normal": 35, "kr-SourceHanSansKR-Regular": 36, "kr-SourceHanSansSC-Bold": 37, "kr-SourceHanSansSC-ExtraLight": 38, "kr-SourceHanSansSC-Heavy": 39, +"kr-SourceHanSansSC-Light": 40, "kr-SourceHanSansSC-Medium": 41, "kr-SourceHanSansSC-Normal": 42, "kr-SourceHanSansSC-Regular": 43, "kr-SourceHanSerifSC-Bold": 44, "kr-SourceHanSerifSC-SemiBold": 45, "kr-TDTDBubbleBubbleOTF": 46, "kr-TDTDConfusionOTF": 47, "kr-TDTDCuteAndCuteOTF": 48, "kr-TDTDEggTakOTF": 49, +"kr-TDTDEmotionalLetterOTF": 50, "kr-TDTDGalapagosOTF": 51, "kr-TDTDHappyHourOTF": 52, "kr-TDTDLatteOTF": 53, "kr-TDTDMoonLightOTF": 54, "kr-TDTDParkForestOTF": 55, "kr-TDTDPencilOTF": 56, "kr-TDTDSmileOTF": 57, "kr-TDTDSproutOTF": 58, "kr-TDTDSunshineOTF": 59, +"kr-TDTDWaferOTF": 60, "kr-777Chyaochyureu": 61, "kr-ArialUnicodeMS-Bold": 62, "kr-ArialUnicodeMS": 63, "kr-BMHANNA": 64, "kr-Baekmuk-Dotum": 65, "kr-BagelFatOne-Regular": 66, "kr-CoreBandi": 67, "kr-CoreBandiFace": 68, "kr-CoreBori": 69, +"kr-DoHyeon-Regular": 70, "kr-Dokdo-Regular": 71, "kr-Gaegu-Bold": 72, "kr-Gaegu-Light": 73, "kr-Gaegu-Regular": 74, "kr-GamjaFlower-Regular": 75, "kr-GasoekOne-Regular": 76, "kr-GothicA1-Black": 77, "kr-GothicA1-Bold": 78, "kr-GothicA1-ExtraBold": 79, +"kr-GothicA1-ExtraLight": 80, "kr-GothicA1-Light": 81, "kr-GothicA1-Medium": 82, "kr-GothicA1-Regular": 83, "kr-GothicA1-SemiBold": 84, "kr-GothicA1-Thin": 85, "kr-Gugi-Regular": 86, "kr-HiMelody-Regular": 87, "kr-Jua-Regular": 88, "kr-KirangHaerang-Regular": 89, +"kr-NanumBrush": 90, "kr-NanumPen": 91, "kr-NanumSquareRoundB": 92, "kr-NanumSquareRoundEB": 93, "kr-NanumSquareRoundL": 94, "kr-NanumSquareRoundR": 95, "kr-SeH-CB": 96, "kr-SeH-CBL": 97, "kr-SeH-CEB": 98, "kr-SeH-CL": 99, +"kr-SeH-CM": 100, "kr-SeN-CB": 101, "kr-SeN-CBL": 102, "kr-SeN-CEB": 103, "kr-SeN-CL": 104, "kr-SeN-CM": 105, "kr-Sunflower-Bold": 106, "kr-Sunflower-Light": 107, "kr-Sunflower-Medium": 108, "kr-TTClaytoyR": 109, +"kr-TTDalpangiR": 110, "kr-TTMamablockR": 111, "kr-TTNauidongmuR": 112, "kr-TTOktapbangR": 113, "kr-UhBeeMiMi": 114, "kr-UhBeeMiMiBold": 115, "kr-UhBeeSe_hyun": 116, "kr-UhBeeSe_hyunBold": 117, "kr-UhBeenamsoyoung": 118, "kr-UhBeenamsoyoungBold": 119, +"kr-WenQuanYiMicroHei": 120, "kr-YeonSung-Regular": 121}""" + + +def add_special_token(tokenizer: T5Tokenizer, text_encoder: T5Stack): + """ + Add special tokens for color and font to tokenizer and text encoder. + + Args: + tokenizer: Huggingface tokenizer. + text_encoder: Huggingface T5 encoder. + """ + idx_font_dict = json.loads(MULTILINGUAL_10_LANG_IDX_JSON) + idx_color_dict = json.loads(COLOR_IDX_JSON) + + font_token = [f"<{font_code[:2]}-font-{idx_font_dict[font_code]}>" for font_code in idx_font_dict] + color_token = [f"" for i in range(len(idx_color_dict))] + additional_special_tokens = [] + additional_special_tokens += color_token + additional_special_tokens += font_token + + tokenizer.add_tokens(additional_special_tokens, special_tokens=True) + # Set mean_resizing=False to avoid PyTorch LAPACK dependency + text_encoder.resize_token_embeddings(len(tokenizer), mean_resizing=False) + + +def load_byt5( + ckpt_path: str, + dtype: Optional[torch.dtype], + device: Union[str, torch.device], + disable_mmap: bool = False, + state_dict: Optional[dict] = None, +) -> Tuple[T5Stack, T5Tokenizer]: + BYT5_CONFIG_JSON = """ +{ + "_name_or_path": "/home/patrick/t5/byt5-small", + "architectures": [ + "T5ForConditionalGeneration" + ], + "d_ff": 3584, + "d_kv": 64, + "d_model": 1472, + "decoder_start_token_id": 0, + "dropout_rate": 0.1, + "eos_token_id": 1, + "feed_forward_proj": "gated-gelu", + "gradient_checkpointing": false, + "initializer_factor": 1.0, + "is_encoder_decoder": true, + "layer_norm_epsilon": 1e-06, + "model_type": "t5", + "num_decoder_layers": 4, + "num_heads": 6, + "num_layers": 12, + "pad_token_id": 0, + "relative_attention_num_buckets": 32, + "tie_word_embeddings": false, + "tokenizer_class": "ByT5Tokenizer", + "transformers_version": "4.7.0.dev0", + "use_cache": true, + "vocab_size": 384 + } +""" + + logger.info(f"Loading BYT5 tokenizer from {BYT5_TOKENIZER_PATH}") + byt5_tokenizer = AutoTokenizer.from_pretrained(BYT5_TOKENIZER_PATH) + + logger.info("Initializing BYT5 text encoder") + config = json.loads(BYT5_CONFIG_JSON) + config = T5Config(**config) + with init_empty_weights(): + byt5_text_encoder = T5ForConditionalGeneration._from_config(config).get_encoder() + + add_special_token(byt5_tokenizer, byt5_text_encoder) + + if state_dict is not None: + sd = state_dict + else: + logger.info(f"Loading state dict from {ckpt_path}") + sd = load_safetensors(ckpt_path, device, disable_mmap=disable_mmap, dtype=dtype) + + # remove "encoder." prefix + sd = {k[len("encoder.") :] if k.startswith("encoder.") else k: v for k, v in sd.items()} + sd["embed_tokens.weight"] = sd.pop("shared.weight") + + info = byt5_text_encoder.load_state_dict(sd, strict=True, assign=True) + byt5_text_encoder.to(device) + byt5_text_encoder.eval() + logger.info(f"BYT5 text encoder loaded with info: {info}") + + return byt5_tokenizer, byt5_text_encoder + + +def load_qwen2_5_vl( + ckpt_path: str, + dtype: Optional[torch.dtype], + device: Union[str, torch.device], + disable_mmap: bool = False, + state_dict: Optional[dict] = None, +) -> tuple[Qwen2Tokenizer, Qwen2_5_VLForConditionalGeneration]: + QWEN2_5_VL_CONFIG_JSON = """ +{ + "architectures": [ + "Qwen2_5_VLForConditionalGeneration" + ], + "attention_dropout": 0.0, + "bos_token_id": 151643, + "eos_token_id": 151645, + "hidden_act": "silu", + "hidden_size": 3584, + "image_token_id": 151655, + "initializer_range": 0.02, + "intermediate_size": 18944, + "max_position_embeddings": 128000, + "max_window_layers": 28, + "model_type": "qwen2_5_vl", + "num_attention_heads": 28, + "num_hidden_layers": 28, + "num_key_value_heads": 4, + "rms_norm_eps": 1e-06, + "rope_scaling": { + "mrope_section": [ + 16, + 24, + 24 + ], + "rope_type": "default", + "type": "default" + }, + "rope_theta": 1000000.0, + "sliding_window": 32768, + "text_config": { + "architectures": [ + "Qwen2_5_VLForConditionalGeneration" + ], + "attention_dropout": 0.0, + "bos_token_id": 151643, + "eos_token_id": 151645, + "hidden_act": "silu", + "hidden_size": 3584, + "image_token_id": null, + "initializer_range": 0.02, + "intermediate_size": 18944, + "layer_types": [ + "full_attention", + "full_attention", + "full_attention", + "full_attention", + "full_attention", + "full_attention", + "full_attention", + "full_attention", + "full_attention", + "full_attention", + "full_attention", + "full_attention", + "full_attention", + "full_attention", + "full_attention", + "full_attention", + "full_attention", + "full_attention", + "full_attention", + "full_attention", + "full_attention", + "full_attention", + "full_attention", + "full_attention", + "full_attention", + "full_attention", + "full_attention", + "full_attention" + ], + "max_position_embeddings": 128000, + "max_window_layers": 28, + "model_type": "qwen2_5_vl_text", + "num_attention_heads": 28, + "num_hidden_layers": 28, + "num_key_value_heads": 4, + "rms_norm_eps": 1e-06, + "rope_scaling": { + "mrope_section": [ + 16, + 24, + 24 + ], + "rope_type": "default", + "type": "default" + }, + "rope_theta": 1000000.0, + "sliding_window": null, + "torch_dtype": "float32", + "use_cache": true, + "use_sliding_window": false, + "video_token_id": null, + "vision_end_token_id": 151653, + "vision_start_token_id": 151652, + "vision_token_id": 151654, + "vocab_size": 152064 + }, + "tie_word_embeddings": false, + "torch_dtype": "bfloat16", + "transformers_version": "4.53.1", + "use_cache": true, + "use_sliding_window": false, + "video_token_id": 151656, + "vision_config": { + "depth": 32, + "fullatt_block_indexes": [ + 7, + 15, + 23, + 31 + ], + "hidden_act": "silu", + "hidden_size": 1280, + "in_channels": 3, + "in_chans": 3, + "initializer_range": 0.02, + "intermediate_size": 3420, + "model_type": "qwen2_5_vl", + "num_heads": 16, + "out_hidden_size": 3584, + "patch_size": 14, + "spatial_merge_size": 2, + "spatial_patch_size": 14, + "temporal_patch_size": 2, + "tokens_per_second": 2, + "torch_dtype": "float32", + "window_size": 112 + }, + "vision_end_token_id": 151653, + "vision_start_token_id": 151652, + "vision_token_id": 151654, + "vocab_size": 152064 +} +""" + config = json.loads(QWEN2_5_VL_CONFIG_JSON) + config = Qwen2_5_VLConfig(**config) + with init_empty_weights(): + qwen2_5_vl = Qwen2_5_VLForConditionalGeneration._from_config(config) + + if state_dict is not None: + sd = state_dict + else: + logger.info(f"Loading state dict from {ckpt_path}") + sd = load_safetensors(ckpt_path, device, disable_mmap=disable_mmap, dtype=dtype) + + # convert prefixes + for key in list(sd.keys()): + if key.startswith("model."): + new_key = key.replace("model.", "model.language_model.", 1) + elif key.startswith("visual."): + new_key = key.replace("visual.", "model.visual.", 1) + else: + continue + if key not in sd: + logger.warning(f"Key {key} not found in state dict, skipping.") + continue + sd[new_key] = sd.pop(key) + + info = qwen2_5_vl.load_state_dict(sd, strict=True, assign=True) + logger.info(f"Loaded Qwen2.5-VL: {info}") + qwen2_5_vl.to(device) + qwen2_5_vl.eval() + + if dtype is not None: + if dtype.itemsize == 1: # fp8 + org_dtype = torch.bfloat16 # model weight is fp8 in loading, but original dtype is bfloat16 + logger.info(f"prepare Qwen2.5-VL for fp8: set to {dtype} from {org_dtype}") + qwen2_5_vl.to(dtype) + + # prepare LLM for fp8 + def prepare_fp8(vl_model: Qwen2_5_VLForConditionalGeneration, target_dtype): + def forward_hook(module): + def forward(hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + module.variance_epsilon) + # return module.weight.to(input_dtype) * hidden_states.to(input_dtype) + return (module.weight.to(torch.float32) * hidden_states.to(torch.float32)).to(input_dtype) + + return forward + + def decoder_forward_hook(module): + def forward( + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + **kwargs, + ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]: + + residual = hidden_states + + hidden_states = module.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights = module.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + input_dtype = hidden_states.dtype + hidden_states = residual.to(torch.float32) + hidden_states.to(torch.float32) + hidden_states = hidden_states.to(input_dtype) + + # Fully Connected + residual = hidden_states + hidden_states = module.post_attention_layernorm(hidden_states) + hidden_states = module.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + return outputs + + return forward + + for module in vl_model.modules(): + if module.__class__.__name__ in ["Embedding"]: + # print("set", module.__class__.__name__, "to", target_dtype) + module.to(target_dtype) + if module.__class__.__name__ in ["Qwen2RMSNorm"]: + # print("set", module.__class__.__name__, "hooks") + module.forward = forward_hook(module) + if module.__class__.__name__ in ["Qwen2_5_VLDecoderLayer"]: + # print("set", module.__class__.__name__, "hooks") + module.forward = decoder_forward_hook(module) + if module.__class__.__name__ in ["Qwen2_5_VisionRotaryEmbedding"]: + # print("set", module.__class__.__name__, "hooks") + module.to(target_dtype) + + prepare_fp8(qwen2_5_vl, org_dtype) + + else: + logger.info(f"Setting Qwen2.5-VL to dtype: {dtype}") + qwen2_5_vl.to(dtype) + + # Load tokenizer + logger.info(f"Loading tokenizer from {QWEN_2_5_VL_IMAGE_ID}") + tokenizer = Qwen2Tokenizer.from_pretrained(QWEN_2_5_VL_IMAGE_ID) + return tokenizer, qwen2_5_vl + + +TOKENIZER_MAX_LENGTH = 1024 +PROMPT_TEMPLATE_ENCODE_START_IDX = 34 + + +def get_qwen_prompt_embeds( + tokenizer: Qwen2Tokenizer, vlm: Qwen2_5_VLForConditionalGeneration, prompt: Union[str, list[str]] = None +) -> Tuple[torch.Tensor, torch.Tensor]: + input_ids, mask = get_qwen_tokens(tokenizer, prompt) + return get_qwen_prompt_embeds_from_tokens(vlm, input_ids, mask) + + +def get_qwen_tokens(tokenizer: Qwen2Tokenizer, prompt: Union[str, list[str]] = None) -> Tuple[torch.Tensor, torch.Tensor]: + tokenizer_max_length = TOKENIZER_MAX_LENGTH + + # HunyuanImage-2.1 does not use "<|im_start|>assistant\n" in the prompt template + prompt_template_encode = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>" + # \n<|im_start|>assistant\n" + prompt_template_encode_start_idx = PROMPT_TEMPLATE_ENCODE_START_IDX + # default_sample_size = 128 + + prompt = [prompt] if isinstance(prompt, str) else prompt + + template = prompt_template_encode + drop_idx = prompt_template_encode_start_idx + txt = [template.format(e) for e in prompt] + txt_tokens = tokenizer(txt, max_length=tokenizer_max_length + drop_idx, padding=True, truncation=True, return_tensors="pt") + return txt_tokens.input_ids, txt_tokens.attention_mask + + +def get_qwen_prompt_embeds_from_tokens( + vlm: Qwen2_5_VLForConditionalGeneration, input_ids: torch.Tensor, attention_mask: torch.Tensor +) -> Tuple[torch.Tensor, torch.Tensor]: + tokenizer_max_length = TOKENIZER_MAX_LENGTH + drop_idx = PROMPT_TEMPLATE_ENCODE_START_IDX + + device = vlm.device + dtype = vlm.dtype + + input_ids = input_ids.to(device=device) + attention_mask = attention_mask.to(device=device) + + if dtype.itemsize == 1: # fp8 + with torch.no_grad(), torch.autocast(device_type=device.type, dtype=torch.bfloat16, enabled=True): + encoder_hidden_states = vlm(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True) + else: + with torch.no_grad(), torch.autocast(device_type=device.type, dtype=dtype, enabled=True): + encoder_hidden_states = vlm(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True) + + hidden_states = encoder_hidden_states.hidden_states[-3] # use the 3rd last layer's hidden states for HunyuanImage-2.1 + if hidden_states.shape[1] > tokenizer_max_length + drop_idx: + logger.warning(f"Hidden states shape {hidden_states.shape} exceeds max length {tokenizer_max_length + drop_idx}") + + # --- Unnecessary complicated processing, keep for reference --- + # split_hidden_states = extract_masked_hidden(hidden_states, txt_tokens.attention_mask) + # split_hidden_states = [e[drop_idx:] for e in split_hidden_states] + # attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states] + # max_seq_len = max([e.size(0) for e in split_hidden_states]) + # prompt_embeds = torch.stack([torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states]) + # encoder_attention_mask = torch.stack([torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list]) + # ---------------------------------------------------------- + + prompt_embeds = hidden_states[:, drop_idx:, :] + encoder_attention_mask = attention_mask[:, drop_idx:] + prompt_embeds = prompt_embeds.to(device=device) + + return prompt_embeds, encoder_attention_mask + + +def format_prompt(texts, styles): + """ + Text "{text}" in {color}, {type}. + """ + + prompt = "" + for text, style in zip(texts, styles): + # color and style are always None in official implementation, so we only use text + text_prompt = f'Text "{text}"' + text_prompt += ". " + prompt = prompt + text_prompt + return prompt + + +BYT5_MAX_LENGTH = 128 + + +def get_glyph_prompt_embeds( + tokenizer: T5Tokenizer, text_encoder: T5Stack, prompt: Optional[str] = None +) -> Tuple[list[bool], torch.Tensor, torch.Tensor]: + byt5_tokens, byt5_text_mask = get_byt5_text_tokens(tokenizer, prompt) + return get_byt5_prompt_embeds_from_tokens(text_encoder, byt5_tokens, byt5_text_mask) + + +def get_byt5_prompt_embeds_from_tokens( + text_encoder: T5Stack, byt5_text_ids: Optional[torch.Tensor], byt5_text_mask: Optional[torch.Tensor] +) -> Tuple[list[bool], torch.Tensor, torch.Tensor]: + byt5_max_length = BYT5_MAX_LENGTH + + if byt5_text_ids is None or byt5_text_mask is None or byt5_text_mask.sum() == 0: + return ( + [False], + torch.zeros((1, byt5_max_length, 1472), device=text_encoder.device), + torch.zeros((1, byt5_max_length), device=text_encoder.device, dtype=torch.int64), + ) + + byt5_text_ids = byt5_text_ids.to(device=text_encoder.device) + byt5_text_mask = byt5_text_mask.to(device=text_encoder.device) + + with torch.no_grad(), torch.autocast(device_type=text_encoder.device.type, dtype=text_encoder.dtype, enabled=True): + byt5_prompt_embeds = text_encoder(byt5_text_ids, attention_mask=byt5_text_mask.float()) + byt5_emb = byt5_prompt_embeds[0] + + return [True], byt5_emb, byt5_text_mask + + +def get_byt5_text_tokens(tokenizer, prompt): + if not prompt: + return None, None + + try: + text_prompt_texts = [] + # pattern_quote_single = r"\'(.*?)\'" + pattern_quote_double = r"\"(.*?)\"" + pattern_quote_chinese_single = r"‘(.*?)’" + pattern_quote_chinese_double = r"“(.*?)”" + + # matches_quote_single = re.findall(pattern_quote_single, prompt) + matches_quote_double = re.findall(pattern_quote_double, prompt) + matches_quote_chinese_single = re.findall(pattern_quote_chinese_single, prompt) + matches_quote_chinese_double = re.findall(pattern_quote_chinese_double, prompt) + + # text_prompt_texts.extend(matches_quote_single) + text_prompt_texts.extend(matches_quote_double) + text_prompt_texts.extend(matches_quote_chinese_single) + text_prompt_texts.extend(matches_quote_chinese_double) + + if not text_prompt_texts: + return None, None + + text_prompt_style_list = [{"color": None, "font-family": None} for _ in range(len(text_prompt_texts))] + glyph_text_formatted = format_prompt(text_prompt_texts, text_prompt_style_list) + logger.info(f"Glyph text formatted: {glyph_text_formatted}") + + byt5_text_inputs = tokenizer( + glyph_text_formatted, + padding="max_length", + max_length=BYT5_MAX_LENGTH, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + + byt5_text_ids = byt5_text_inputs.input_ids + byt5_text_mask = byt5_text_inputs.attention_mask + + return byt5_text_ids, byt5_text_mask + + except Exception as e: + logger.warning(f"Warning: Error in glyph encoding, using fallback: {e}") + return None, None diff --git a/library/hunyuan_image_utils.py b/library/hunyuan_image_utils.py new file mode 100644 index 000000000..8e95925ca --- /dev/null +++ b/library/hunyuan_image_utils.py @@ -0,0 +1,525 @@ +# Original work: https://github.com/Tencent-Hunyuan/HunyuanImage-2.1 +# Re-implemented for license compliance for sd-scripts. + +import math +from typing import Tuple, Union, Optional +import torch + +from library.utils import setup_logging + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + + +MODEL_VERSION_2_1 = "hunyuan-image-2.1" + +# region model + + +def _to_tuple(x, dim=2): + """ + Convert int or sequence to tuple of specified dimension. + + Args: + x: Int or sequence to convert. + dim: Target dimension for tuple. + + Returns: + Tuple of length dim. + """ + if isinstance(x, int) or isinstance(x, float): + return (x,) * dim + elif len(x) == dim: + return x + else: + raise ValueError(f"Expected length {dim} or int, but got {x}") + + +def get_meshgrid_nd(start, dim=2): + """ + Generate n-dimensional coordinate meshgrid from 0 to grid_size. + + Creates coordinate grids for each spatial dimension, useful for + generating position embeddings. + + Args: + start: Grid size for each dimension (int or tuple). + dim: Number of spatial dimensions. + + Returns: + Coordinate grid tensor [dim, *grid_size]. + """ + # Convert start to grid sizes + num = _to_tuple(start, dim=dim) + start = (0,) * dim + stop = num + + # Generate coordinate arrays for each dimension + axis_grid = [] + for i in range(dim): + a, b, n = start[i], stop[i], num[i] + g = torch.linspace(a, b, n + 1, dtype=torch.float32)[:n] + axis_grid.append(g) + grid = torch.meshgrid(*axis_grid, indexing="ij") # dim x [W, H, D] + grid = torch.stack(grid, dim=0) # [dim, W, H, D] + + return grid + + +def get_nd_rotary_pos_embed(rope_dim_list, start, theta=10000.0): + """ + Generate n-dimensional rotary position embeddings for spatial tokens. + + Creates RoPE embeddings for multi-dimensional positional encoding, + distributing head dimensions across spatial dimensions. + + Args: + rope_dim_list: Dimensions allocated to each spatial axis (should sum to head_dim). + start: Spatial grid size for each dimension. + theta: Base frequency for RoPE computation. + + Returns: + Tuple of (cos_freqs, sin_freqs) for rotary embedding [H*W, D/2]. + """ + + grid = get_meshgrid_nd(start, dim=len(rope_dim_list)) # [3, W, H, D] / [2, W, H] + + # Generate RoPE embeddings for each spatial dimension + embs = [] + for i in range(len(rope_dim_list)): + emb = get_1d_rotary_pos_embed(rope_dim_list[i], grid[i].reshape(-1), theta) # 2 x [WHD, rope_dim_list[i]] + embs.append(emb) + + cos = torch.cat([emb[0] for emb in embs], dim=1) # (WHD, D/2) + sin = torch.cat([emb[1] for emb in embs], dim=1) # (WHD, D/2) + return cos, sin + + +def get_1d_rotary_pos_embed( + dim: int, pos: Union[torch.FloatTensor, int], theta: float = 10000.0 +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Generate 1D rotary position embeddings. + + Args: + dim: Embedding dimension (must be even). + pos: Position indices [S] or scalar for sequence length. + theta: Base frequency for sinusoidal encoding. + + Returns: + Tuple of (cos_freqs, sin_freqs) tensors [S, D]. + """ + if isinstance(pos, int): + pos = torch.arange(pos).float() + + freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) # [D/2] + freqs = torch.outer(pos, freqs) # [S, D/2] + freqs_cos = freqs.cos().repeat_interleave(2, dim=1) # [S, D] + freqs_sin = freqs.sin().repeat_interleave(2, dim=1) # [S, D] + return freqs_cos, freqs_sin + + +def timestep_embedding(t, dim, max_period=10000): + """ + Create sinusoidal timestep embeddings for diffusion models. + + Converts scalar timesteps to high-dimensional embeddings using + sinusoidal encoding at different frequencies. + + Args: + t: Timestep tensor [N]. + dim: Output embedding dimension. + max_period: Maximum period for frequency computation. + + Returns: + Timestep embeddings [N, dim]. + """ + half = dim // 2 + freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(device=t.device) + args = t[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + return embedding + + +def modulate(x, shift=None, scale=None): + """ + Apply adaptive layer normalization modulation. + + Applies scale and shift transformations for conditioning + in adaptive layer normalization. + + Args: + x: Input tensor to modulate. + shift: Additive shift parameter (optional). + scale: Multiplicative scale parameter (optional). + + Returns: + Modulated tensor x * (1 + scale) + shift. + """ + if scale is None and shift is None: + return x + elif shift is None: + return x * (1 + scale.unsqueeze(1)) + elif scale is None: + return x + shift.unsqueeze(1) + else: + return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) + + +def apply_gate(x, gate=None, tanh=False): + """ + Apply gating mechanism to tensor. + + Multiplies input by gate values, optionally applying tanh activation. + Used in residual connections for adaptive control. + + Args: + x: Input tensor to gate. + gate: Gating values (optional). + tanh: Whether to apply tanh to gate values. + + Returns: + Gated tensor x * gate (with optional tanh). + """ + if gate is None: + return x + if tanh: + return x * gate.unsqueeze(1).tanh() + else: + return x * gate.unsqueeze(1) + + +def reshape_for_broadcast( + freqs_cis: Tuple[torch.Tensor, torch.Tensor], + x: torch.Tensor, + head_first=False, +): + """ + Reshape RoPE frequency tensors for broadcasting with attention tensors. + + Args: + freqs_cis: Tuple of (cos_freqs, sin_freqs) tensors. + x: Target tensor for broadcasting compatibility. + head_first: Must be False (only supported layout). + + Returns: + Reshaped (cos_freqs, sin_freqs) tensors ready for broadcasting. + """ + assert not head_first, "Only head_first=False layout supported." + assert isinstance(freqs_cis, tuple), "Expected tuple of (cos, sin) frequency tensors." + assert x.ndim > 1, f"x should have at least 2 dimensions, but got {x.ndim}" + + # Validate frequency tensor dimensions match target tensor + assert freqs_cis[0].shape == ( + x.shape[1], + x.shape[-1], + ), f"Frequency tensor shape {freqs_cis[0].shape} incompatible with target shape {x.shape}" + + shape = [d if i == 1 or i == x.ndim - 1 else 1 for i, d in enumerate(x.shape)] + return freqs_cis[0].view(*shape), freqs_cis[1].view(*shape) + + +def rotate_half(x): + """ + Rotate half the dimensions for RoPE computation. + + Splits the last dimension in half and applies a 90-degree rotation + by swapping and negating components. + + Args: + x: Input tensor [..., D] where D is even. + + Returns: + Rotated tensor with same shape as input. + """ + x_real, x_imag = x.float().reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2] + return torch.stack([-x_imag, x_real], dim=-1).flatten(3) + + +def apply_rotary_emb( + xq: torch.Tensor, xk: torch.Tensor, freqs_cis: Tuple[torch.Tensor, torch.Tensor], head_first: bool = False +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Apply rotary position embeddings to query and key tensors. + + Args: + xq: Query tensor [B, S, H, D]. + xk: Key tensor [B, S, H, D]. + freqs_cis: Tuple of (cos_freqs, sin_freqs) for rotation. + head_first: Whether head dimension precedes sequence dimension. + + Returns: + Tuple of rotated (query, key) tensors. + """ + device = xq.device + dtype = xq.dtype + + cos, sin = reshape_for_broadcast(freqs_cis, xq, head_first) + cos, sin = cos.to(device), sin.to(device) + + # Apply rotation: x' = x * cos + rotate_half(x) * sin + xq_out = (xq.float() * cos + rotate_half(xq.float()) * sin).to(dtype) + xk_out = (xk.float() * cos + rotate_half(xk.float()) * sin).to(dtype) + + return xq_out, xk_out + + +# endregion + +# region inference + + +def get_timesteps_sigmas(sampling_steps: int, shift: float, device: torch.device) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Generate timesteps and sigmas for diffusion sampling. + + Args: + sampling_steps: Number of sampling steps. + shift: Sigma shift parameter for schedule modification. + device: Target device for tensors. + + Returns: + Tuple of (timesteps, sigmas) tensors. + """ + sigmas = torch.linspace(1, 0, sampling_steps + 1) + sigmas = (shift * sigmas) / (1 + (shift - 1) * sigmas) + sigmas = sigmas.to(torch.float32) + timesteps = (sigmas[:-1] * 1000).to(dtype=torch.float32, device=device) + return timesteps, sigmas + + +def step(latents, noise_pred, sigmas, step_i): + """ + Perform a single diffusion sampling step. + + Args: + latents: Current latent state. + noise_pred: Predicted noise. + sigmas: Noise schedule sigmas. + step_i: Current step index. + + Returns: + Updated latents after the step. + """ + return latents.float() - (sigmas[step_i] - sigmas[step_i + 1]) * noise_pred.float() + + +# endregion + + +# region AdaptiveProjectedGuidance + + +class MomentumBuffer: + """ + Exponential moving average buffer for APG momentum. + """ + + def __init__(self, momentum: float): + self.momentum = momentum + self.running_average = 0 + + def update(self, update_value: torch.Tensor): + new_average = self.momentum * self.running_average + self.running_average = update_value + new_average + + +def normalized_guidance_apg( + pred_cond: torch.Tensor, + pred_uncond: torch.Tensor, + guidance_scale: float, + momentum_buffer: Optional[MomentumBuffer] = None, + eta: float = 1.0, + norm_threshold: float = 0.0, + use_original_formulation: bool = False, +): + """ + Apply normalized adaptive projected guidance. + + Projects the guidance vector to reduce over-saturation while maintaining + directional control by decomposing into parallel and orthogonal components. + + Args: + pred_cond: Conditional prediction. + pred_uncond: Unconditional prediction. + guidance_scale: Guidance scale factor. + momentum_buffer: Optional momentum buffer for temporal smoothing. + eta: Scaling factor for parallel component. + norm_threshold: Maximum norm for guidance vector clipping. + use_original_formulation: Whether to use original APG formulation. + + Returns: + Guided prediction tensor. + """ + diff = pred_cond - pred_uncond + dim = [-i for i in range(1, len(diff.shape))] # All dimensions except batch + + # Apply momentum smoothing if available + if momentum_buffer is not None: + momentum_buffer.update(diff) + diff = momentum_buffer.running_average + + # Apply norm clipping if threshold is set + if norm_threshold > 0: + diff_norm = diff.norm(p=2, dim=dim, keepdim=True) + scale_factor = torch.minimum(torch.ones_like(diff_norm), norm_threshold / diff_norm) + diff = diff * scale_factor + + # Project guidance vector into parallel and orthogonal components + v0, v1 = diff.double(), pred_cond.double() + v1 = torch.nn.functional.normalize(v1, dim=dim) + v0_parallel = (v0 * v1).sum(dim=dim, keepdim=True) * v1 + v0_orthogonal = v0 - v0_parallel + diff_parallel, diff_orthogonal = v0_parallel.type_as(diff), v0_orthogonal.type_as(diff) + + # Combine components with different scaling + normalized_update = diff_orthogonal + eta * diff_parallel + pred = pred_cond if use_original_formulation else pred_uncond + pred = pred + guidance_scale * normalized_update + + return pred + + +class AdaptiveProjectedGuidance: + """ + Adaptive Projected Guidance for classifier-free guidance. + + Implements APG which projects the guidance vector to reduce over-saturation + while maintaining directional control. + """ + + def __init__( + self, + guidance_scale: float = 7.5, + adaptive_projected_guidance_momentum: Optional[float] = None, + adaptive_projected_guidance_rescale: float = 15.0, + eta: float = 0.0, + guidance_rescale: float = 0.0, + use_original_formulation: bool = False, + ): + self.guidance_scale = guidance_scale + self.adaptive_projected_guidance_momentum = adaptive_projected_guidance_momentum + self.adaptive_projected_guidance_rescale = adaptive_projected_guidance_rescale + self.eta = eta + self.guidance_rescale = guidance_rescale + self.use_original_formulation = use_original_formulation + self.momentum_buffer = None + + def __call__(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None, step=None) -> torch.Tensor: + if step == 0 and self.adaptive_projected_guidance_momentum is not None: + self.momentum_buffer = MomentumBuffer(self.adaptive_projected_guidance_momentum) + + pred = normalized_guidance_apg( + pred_cond, + pred_uncond, + self.guidance_scale, + self.momentum_buffer, + self.eta, + self.adaptive_projected_guidance_rescale, + self.use_original_formulation, + ) + + if self.guidance_rescale > 0.0: + pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale) + + return pred + + +def rescale_noise_cfg(guided_noise, conditional_noise, rescale_factor=0.0): + """ + Rescale guided noise prediction to prevent overexposure and improve image quality. + + This implementation addresses the overexposure issue described in "Common Diffusion Noise + Schedules and Sample Steps are Flawed" (https://arxiv.org/pdf/2305.08891.pdf) (Section 3.4). + The rescaling preserves the statistical properties of the conditional prediction while reducing artifacts. + + Args: + guided_noise (torch.Tensor): Noise prediction from classifier-free guidance. + conditional_noise (torch.Tensor): Noise prediction from conditional model. + rescale_factor (float): Interpolation factor between original and rescaled predictions. + 0.0 = no rescaling, 1.0 = full rescaling. + + Returns: + torch.Tensor: Rescaled noise prediction with reduced overexposure. + """ + if rescale_factor == 0.0: + return guided_noise + + # Calculate standard deviation across spatial dimensions for both predictions + spatial_dims = list(range(1, conditional_noise.ndim)) + conditional_std = conditional_noise.std(dim=spatial_dims, keepdim=True) + guided_std = guided_noise.std(dim=spatial_dims, keepdim=True) + + # Rescale guided noise to match conditional noise statistics + std_ratio = conditional_std / guided_std + rescaled_prediction = guided_noise * std_ratio + + # Interpolate between original and rescaled predictions + final_prediction = rescale_factor * rescaled_prediction + (1.0 - rescale_factor) * guided_noise + + return final_prediction + + +def apply_classifier_free_guidance( + noise_pred_text: torch.Tensor, + noise_pred_uncond: torch.Tensor, + is_ocr: bool, + guidance_scale: float, + step: int, + apg_start_step_ocr: int = 38, + apg_start_step_general: int = 5, + cfg_guider_ocr: AdaptiveProjectedGuidance = None, + cfg_guider_general: AdaptiveProjectedGuidance = None, + guidance_rescale: float = 0.0, +): + """ + Apply classifier-free guidance with OCR-aware APG for batch_size=1. + + Args: + noise_pred_text: Conditional noise prediction tensor [1, ...]. + noise_pred_uncond: Unconditional noise prediction tensor [1, ...]. + is_ocr: Whether this sample requires OCR-specific guidance. + guidance_scale: Guidance scale for CFG. + step: Current diffusion step index. + apg_start_step_ocr: Step to start APG for OCR regions. + apg_start_step_general: Step to start APG for general regions. + cfg_guider_ocr: APG guider for OCR regions. + cfg_guider_general: APG guider for general regions. + + Returns: + Guided noise prediction tensor [1, ...]. + """ + if guidance_scale == 1.0: + return noise_pred_text + + # Select appropriate guider and start step based on OCR requirement + if is_ocr: + cfg_guider = cfg_guider_ocr + apg_start_step = apg_start_step_ocr + else: + cfg_guider = cfg_guider_general + apg_start_step = apg_start_step_general + + # Apply standard CFG or APG based on current step + if step <= apg_start_step: + # Standard classifier-free guidance + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + if guidance_rescale > 0.0: + # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf + noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale) + + # Initialize APG guider state + _ = cfg_guider(noise_pred_text, noise_pred_uncond, step=step) + else: + # Use APG for guidance + noise_pred = cfg_guider(noise_pred_text, noise_pred_uncond, step=step) + + return noise_pred + + +# endregion diff --git a/library/hunyuan_image_vae.py b/library/hunyuan_image_vae.py new file mode 100644 index 000000000..a6ed1e811 --- /dev/null +++ b/library/hunyuan_image_vae.py @@ -0,0 +1,755 @@ +from typing import Optional, Tuple + +from einops import rearrange +import numpy as np +import torch +from torch import Tensor, nn +from torch.nn import Conv2d +from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution + +from library.safetensors_utils import load_safetensors +from library.utils import setup_logging + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + + +VAE_SCALE_FACTOR = 32 # 32x spatial compression + +LATENT_SCALING_FACTOR = 0.75289 # Latent scaling factor for Hunyuan Image-2.1 + + +def swish(x: Tensor) -> Tensor: + """Swish activation function: x * sigmoid(x).""" + return x * torch.sigmoid(x) + + +class AttnBlock(nn.Module): + """Self-attention block using scaled dot-product attention.""" + + def __init__(self, in_channels: int, chunk_size: Optional[int] = None): + super().__init__() + self.in_channels = in_channels + self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) + if chunk_size is None or chunk_size <= 0: + self.q = Conv2d(in_channels, in_channels, kernel_size=1) + self.k = Conv2d(in_channels, in_channels, kernel_size=1) + self.v = Conv2d(in_channels, in_channels, kernel_size=1) + self.proj_out = Conv2d(in_channels, in_channels, kernel_size=1) + else: + self.q = ChunkedConv2d(in_channels, in_channels, kernel_size=1, chunk_size=chunk_size) + self.k = ChunkedConv2d(in_channels, in_channels, kernel_size=1, chunk_size=chunk_size) + self.v = ChunkedConv2d(in_channels, in_channels, kernel_size=1, chunk_size=chunk_size) + self.proj_out = ChunkedConv2d(in_channels, in_channels, kernel_size=1, chunk_size=chunk_size) + + def attention(self, x: Tensor) -> Tensor: + x = self.norm(x) + q = self.q(x) + k = self.k(x) + v = self.v(x) + + b, c, h, w = q.shape + q = rearrange(q, "b c h w -> b (h w) c").contiguous() + k = rearrange(k, "b c h w -> b (h w) c").contiguous() + v = rearrange(v, "b c h w -> b (h w) c").contiguous() + + x = nn.functional.scaled_dot_product_attention(q, k, v) + return rearrange(x, "b (h w) c -> b c h w", h=h, w=w, c=c, b=b) + + def forward(self, x: Tensor) -> Tensor: + return x + self.proj_out(self.attention(x)) + + +class ChunkedConv2d(nn.Conv2d): + """ + Convolutional layer that processes input in chunks to reduce memory usage. + + Parameters + ---------- + chunk_size : int, optional + Size of chunks to process at a time. Default is 64. + """ + + def __init__(self, *args, **kwargs): + if "chunk_size" in kwargs: + self.chunk_size = kwargs.pop("chunk_size", 64) + super().__init__(*args, **kwargs) + assert self.padding_mode == "zeros", "Only 'zeros' padding mode is supported." + assert self.dilation == (1, 1) and self.stride == (1, 1), "Only dilation=1 and stride=1 are supported." + assert self.groups == 1, "Only groups=1 is supported." + assert self.kernel_size[0] == self.kernel_size[1], "Only square kernels are supported." + assert ( + self.padding[0] == self.padding[1] and self.padding[0] == self.kernel_size[0] // 2 + ), "Only kernel_size//2 padding is supported." + self.original_padding = self.padding + self.padding = (0, 0) # We handle padding manually in forward + + def forward(self, x: Tensor) -> Tensor: + # If chunking is not needed, process normally. We chunk only along height dimension. + if self.chunk_size is None or x.shape[1] <= self.chunk_size: + self.padding = self.original_padding + x = super().forward(x) + self.padding = (0, 0) + if torch.cuda.is_available(): + torch.cuda.empty_cache() + return x + + # Process input in chunks to reduce memory usage + org_shape = x.shape + + # If kernel size is not 1, we need to use overlapping chunks + overlap = self.kernel_size[0] // 2 # 1 for kernel size 3 + step = self.chunk_size - overlap + y = torch.zeros((org_shape[0], self.out_channels, org_shape[2], org_shape[3]), dtype=x.dtype, device=x.device) + yi = 0 + i = 0 + while i < org_shape[2]: + si = i if i == 0 else i - overlap + ei = i + self.chunk_size + + # Check last chunk. If remaining part is small, include it in last chunk + if ei > org_shape[2] or ei + step // 4 > org_shape[2]: + ei = org_shape[2] + + chunk = x[:, :, : ei - si, :] + x = x[:, :, ei - si - overlap * 2 :, :] + + # Pad chunk if needed: This is as the original Conv2d with padding + if i == 0: # First chunk + # Pad except bottom + chunk = torch.nn.functional.pad(chunk, (overlap, overlap, overlap, 0), mode="constant", value=0) + elif ei == org_shape[2]: # Last chunk + # Pad except top + chunk = torch.nn.functional.pad(chunk, (overlap, overlap, 0, overlap), mode="constant", value=0) + else: + # Pad left and right only + chunk = torch.nn.functional.pad(chunk, (overlap, overlap), mode="constant", value=0) + + chunk = super().forward(chunk) + y[:, :, yi : yi + chunk.shape[2], :] = chunk + yi += chunk.shape[2] + del chunk + + if ei == org_shape[2]: + break + i += step + + assert yi == org_shape[2], f"yi={yi}, org_shape[2]={org_shape[2]}" + + if torch.cuda.is_available(): + torch.cuda.empty_cache() # This helps reduce peak memory usage, but slows down a bit + return y + + +class ResnetBlock(nn.Module): + """ + Residual block with two convolutions, group normalization, and swish activation. + Includes skip connection with optional channel dimension matching. + + Parameters + ---------- + in_channels : int + Number of input channels. + out_channels : int + Number of output channels. + """ + + def __init__(self, in_channels: int, out_channels: int, chunk_size: Optional[int] = None): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + + self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) + self.norm2 = nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, affine=True) + if chunk_size is None or chunk_size <= 0: + self.conv1 = Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) + self.conv2 = Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) + + # Skip connection projection for channel dimension mismatch + if self.in_channels != self.out_channels: + self.nin_shortcut = Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) + else: + self.conv1 = ChunkedConv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, chunk_size=chunk_size) + self.conv2 = ChunkedConv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, chunk_size=chunk_size) + + # Skip connection projection for channel dimension mismatch + if self.in_channels != self.out_channels: + self.nin_shortcut = ChunkedConv2d( + in_channels, out_channels, kernel_size=1, stride=1, padding=0, chunk_size=chunk_size + ) + + def forward(self, x: Tensor) -> Tensor: + h = x + # First convolution block + h = self.norm1(h) + h = swish(h) + h = self.conv1(h) + # Second convolution block + h = self.norm2(h) + h = swish(h) + h = self.conv2(h) + + # Apply skip connection with optional projection + if self.in_channels != self.out_channels: + x = self.nin_shortcut(x) + return x + h + + +class Downsample(nn.Module): + """ + Spatial downsampling block that reduces resolution by 2x using convolution followed by + pixel rearrangement. Includes skip connection with grouped averaging. + + Parameters + ---------- + in_channels : int + Number of input channels. + out_channels : int + Number of output channels (must be divisible by 4). + """ + + def __init__(self, in_channels: int, out_channels: int, chunk_size: Optional[int] = None): + super().__init__() + factor = 4 # 2x2 spatial reduction factor + assert out_channels % factor == 0 + + if chunk_size is None or chunk_size <= 0: + self.conv = Conv2d(in_channels, out_channels // factor, kernel_size=3, stride=1, padding=1) + else: + self.conv = ChunkedConv2d( + in_channels, out_channels // factor, kernel_size=3, stride=1, padding=1, chunk_size=chunk_size + ) + self.group_size = factor * in_channels // out_channels + + def forward(self, x: Tensor) -> Tensor: + # Apply convolution and rearrange pixels for 2x downsampling + h = self.conv(x) + h = rearrange(h, "b c (h r1) (w r2) -> b (r1 r2 c) h w", r1=2, r2=2) + + # Create skip connection with pixel rearrangement + shortcut = rearrange(x, "b c (h r1) (w r2) -> b (r1 r2 c) h w", r1=2, r2=2) + B, C, H, W = shortcut.shape + shortcut = shortcut.view(B, h.shape[1], self.group_size, H, W).mean(dim=2) + + return h + shortcut + + +class Upsample(nn.Module): + """ + Spatial upsampling block that increases resolution by 2x using convolution followed by + pixel rearrangement. Includes skip connection with channel repetition. + + Parameters + ---------- + in_channels : int + Number of input channels. + out_channels : int + Number of output channels. + """ + + def __init__(self, in_channels: int, out_channels: int, chunk_size: Optional[int] = None): + super().__init__() + factor = 4 # 2x2 spatial expansion factor + + if chunk_size is None or chunk_size <= 0: + self.conv = Conv2d(in_channels, out_channels * factor, kernel_size=3, stride=1, padding=1) + else: + self.conv = ChunkedConv2d(in_channels, out_channels * factor, kernel_size=3, stride=1, padding=1, chunk_size=chunk_size) + + self.repeats = factor * out_channels // in_channels + + def forward(self, x: Tensor) -> Tensor: + # Apply convolution and rearrange pixels for 2x upsampling + h = self.conv(x) + h = rearrange(h, "b (r1 r2 c) h w -> b c (h r1) (w r2)", r1=2, r2=2) + + # Create skip connection with channel repetition + shortcut = x.repeat_interleave(repeats=self.repeats, dim=1) + shortcut = rearrange(shortcut, "b (r1 r2 c) h w -> b c (h r1) (w r2)", r1=2, r2=2) + + return h + shortcut + + +class Encoder(nn.Module): + """ + VAE encoder that progressively downsamples input images to a latent representation. + Uses residual blocks, attention, and spatial downsampling. + + Parameters + ---------- + in_channels : int + Number of input image channels (e.g., 3 for RGB). + z_channels : int + Number of latent channels in the output. + block_out_channels : Tuple[int, ...] + Output channels for each downsampling block. + num_res_blocks : int + Number of residual blocks per downsampling stage. + ffactor_spatial : int + Total spatial downsampling factor (e.g., 32 for 32x compression). + """ + + def __init__( + self, + in_channels: int, + z_channels: int, + block_out_channels: Tuple[int, ...], + num_res_blocks: int, + ffactor_spatial: int, + chunk_size: Optional[int] = None, + ): + super().__init__() + assert block_out_channels[-1] % (2 * z_channels) == 0 + + self.z_channels = z_channels + self.block_out_channels = block_out_channels + self.num_res_blocks = num_res_blocks + + if chunk_size is None or chunk_size <= 0: + self.conv_in = Conv2d(in_channels, block_out_channels[0], kernel_size=3, stride=1, padding=1) + else: + self.conv_in = ChunkedConv2d( + in_channels, block_out_channels[0], kernel_size=3, stride=1, padding=1, chunk_size=chunk_size + ) + + self.down = nn.ModuleList() + block_in = block_out_channels[0] + + # Build downsampling blocks + for i_level, ch in enumerate(block_out_channels): + block = nn.ModuleList() + block_out = ch + + # Add residual blocks for this level + for _ in range(self.num_res_blocks): + block.append(ResnetBlock(in_channels=block_in, out_channels=block_out, chunk_size=chunk_size)) + block_in = block_out + + down = nn.Module() + down.block = block + + # Add spatial downsampling if needed + add_spatial_downsample = bool(i_level < np.log2(ffactor_spatial)) + if add_spatial_downsample: + assert i_level < len(block_out_channels) - 1 + block_out = block_out_channels[i_level + 1] + down.downsample = Downsample(block_in, block_out, chunk_size=chunk_size) + block_in = block_out + + self.down.append(down) + + # Middle blocks with attention + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in, chunk_size=chunk_size) + self.mid.attn_1 = AttnBlock(block_in, chunk_size=chunk_size) + self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in, chunk_size=chunk_size) + + # Output layers + self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True) + if chunk_size is None or chunk_size <= 0: + self.conv_out = Conv2d(block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1) + else: + self.conv_out = ChunkedConv2d(block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1, chunk_size=chunk_size) + + def forward(self, x: Tensor) -> Tensor: + # Initial convolution + h = self.conv_in(x) + + # Progressive downsampling through blocks + for i_level in range(len(self.block_out_channels)): + # Apply residual blocks at this level + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](h) + # Apply spatial downsampling if available + if hasattr(self.down[i_level], "downsample"): + h = self.down[i_level].downsample(h) + + # Middle processing with attention + h = self.mid.block_1(h) + h = self.mid.attn_1(h) + h = self.mid.block_2(h) + + # Final output layers with skip connection + group_size = self.block_out_channels[-1] // (2 * self.z_channels) + shortcut = rearrange(h, "b (c r) h w -> b c r h w", r=group_size).mean(dim=2) + h = self.norm_out(h) + h = swish(h) + h = self.conv_out(h) + h += shortcut + return h + + +class Decoder(nn.Module): + """ + VAE decoder that progressively upsamples latent representations back to images. + Uses residual blocks, attention, and spatial upsampling. + + Parameters + ---------- + z_channels : int + Number of latent channels in the input. + out_channels : int + Number of output image channels (e.g., 3 for RGB). + block_out_channels : Tuple[int, ...] + Output channels for each upsampling block. + num_res_blocks : int + Number of residual blocks per upsampling stage. + ffactor_spatial : int + Total spatial upsampling factor (e.g., 32 for 32x expansion). + """ + + def __init__( + self, + z_channels: int, + out_channels: int, + block_out_channels: Tuple[int, ...], + num_res_blocks: int, + ffactor_spatial: int, + chunk_size: Optional[int] = None, + ): + super().__init__() + assert block_out_channels[0] % z_channels == 0 + + self.z_channels = z_channels + self.block_out_channels = block_out_channels + self.num_res_blocks = num_res_blocks + + block_in = block_out_channels[0] + if chunk_size is None or chunk_size <= 0: + self.conv_in = Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1) + else: + self.conv_in = ChunkedConv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1, chunk_size=chunk_size) + + # Middle blocks with attention + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in, chunk_size=chunk_size) + self.mid.attn_1 = AttnBlock(block_in, chunk_size=chunk_size) + self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in, chunk_size=chunk_size) + + # Build upsampling blocks + self.up = nn.ModuleList() + for i_level, ch in enumerate(block_out_channels): + block = nn.ModuleList() + block_out = ch + + # Add residual blocks for this level (extra block for decoder) + for _ in range(self.num_res_blocks + 1): + block.append(ResnetBlock(in_channels=block_in, out_channels=block_out, chunk_size=chunk_size)) + block_in = block_out + + up = nn.Module() + up.block = block + + # Add spatial upsampling if needed + add_spatial_upsample = bool(i_level < np.log2(ffactor_spatial)) + if add_spatial_upsample: + assert i_level < len(block_out_channels) - 1 + block_out = block_out_channels[i_level + 1] + up.upsample = Upsample(block_in, block_out, chunk_size=chunk_size) + block_in = block_out + + self.up.append(up) + + # Output layers + self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True) + if chunk_size is None or chunk_size <= 0: + self.conv_out = Conv2d(block_in, out_channels, kernel_size=3, stride=1, padding=1) + else: + self.conv_out = ChunkedConv2d(block_in, out_channels, kernel_size=3, stride=1, padding=1, chunk_size=chunk_size) + + def forward(self, z: Tensor) -> Tensor: + # Initial processing with skip connection + repeats = self.block_out_channels[0] // self.z_channels + h = self.conv_in(z) + z.repeat_interleave(repeats=repeats, dim=1) + + # Middle processing with attention + h = self.mid.block_1(h) + h = self.mid.attn_1(h) + h = self.mid.block_2(h) + + # Progressive upsampling through blocks + for i_level in range(len(self.block_out_channels)): + # Apply residual blocks at this level + for i_block in range(self.num_res_blocks + 1): + h = self.up[i_level].block[i_block](h) + # Apply spatial upsampling if available + if hasattr(self.up[i_level], "upsample"): + h = self.up[i_level].upsample(h) + + # Final output layers + h = self.norm_out(h) + h = swish(h) + h = self.conv_out(h) + return h + + +class HunyuanVAE2D(nn.Module): + """ + VAE model for Hunyuan Image-2.1 with spatial tiling support. + + This VAE uses a fixed architecture optimized for the Hunyuan Image-2.1 model, + with 32x spatial compression and optional memory-efficient tiling for large images. + """ + + def __init__(self, chunk_size: Optional[int] = None): + super().__init__() + + # Fixed configuration for Hunyuan Image-2.1 + block_out_channels = (128, 256, 512, 512, 1024, 1024) + in_channels = 3 # RGB input + out_channels = 3 # RGB output + latent_channels = 64 + layers_per_block = 2 + ffactor_spatial = 32 # 32x spatial compression + sample_size = 384 # Minimum sample size for tiling + scaling_factor = LATENT_SCALING_FACTOR # 0.75289 # Latent scaling factor + + self.ffactor_spatial = ffactor_spatial + self.scaling_factor = scaling_factor + + self.encoder = Encoder( + in_channels=in_channels, + z_channels=latent_channels, + block_out_channels=block_out_channels, + num_res_blocks=layers_per_block, + ffactor_spatial=ffactor_spatial, + chunk_size=chunk_size, + ) + + self.decoder = Decoder( + z_channels=latent_channels, + out_channels=out_channels, + block_out_channels=list(reversed(block_out_channels)), + num_res_blocks=layers_per_block, + ffactor_spatial=ffactor_spatial, + chunk_size=chunk_size, + ) + + # Spatial tiling configuration for memory efficiency + self.use_spatial_tiling = False + self.tile_sample_min_size = sample_size + self.tile_latent_min_size = sample_size // ffactor_spatial + self.tile_overlap_factor = 0.25 # 25% overlap between tiles + + @property + def dtype(self): + """Get the data type of the model parameters.""" + return next(self.encoder.parameters()).dtype + + @property + def device(self): + """Get the device of the model parameters.""" + return next(self.encoder.parameters()).device + + def enable_spatial_tiling(self, use_tiling: bool = True): + """Enable or disable spatial tiling.""" + self.use_spatial_tiling = use_tiling + + def disable_spatial_tiling(self): + """Disable spatial tiling.""" + self.use_spatial_tiling = False + + def enable_tiling(self, use_tiling: bool = True): + """Enable or disable spatial tiling (alias for enable_spatial_tiling).""" + self.enable_spatial_tiling(use_tiling) + + def disable_tiling(self): + """Disable spatial tiling (alias for disable_spatial_tiling).""" + self.disable_spatial_tiling() + + def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: + """ + Blend two tensors horizontally with smooth transition. + + Parameters + ---------- + a : torch.Tensor + Left tensor. + b : torch.Tensor + Right tensor. + blend_extent : int + Number of columns to blend. + """ + blend_extent = min(a.shape[-1], b.shape[-1], blend_extent) + for x in range(blend_extent): + b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent) + return b + + def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: + """ + Blend two tensors vertically with smooth transition. + + Parameters + ---------- + a : torch.Tensor + Top tensor. + b : torch.Tensor + Bottom tensor. + blend_extent : int + Number of rows to blend. + """ + blend_extent = min(a.shape[-2], b.shape[-2], blend_extent) + for y in range(blend_extent): + b[:, :, y, :] = a[:, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, y, :] * (y / blend_extent) + return b + + def spatial_tiled_encode(self, x: torch.Tensor) -> torch.Tensor: + """ + Encode large images using spatial tiling to reduce memory usage. + Tiles are processed independently and blended at boundaries. + + Parameters + ---------- + x : torch.Tensor + Input tensor of shape (B, C, T, H, W) or (B, C, H, W). + """ + # Handle 5D input (B, C, T, H, W) by removing time dimension + original_ndim = x.ndim + if original_ndim == 5: + x = x.squeeze(2) + + B, C, H, W = x.shape + overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor)) + blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor) + row_limit = self.tile_latent_min_size - blend_extent + + rows = [] + for i in range(0, H, overlap_size): + row = [] + for j in range(0, W, overlap_size): + tile = x[:, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size] + tile = self.encoder(tile) + row.append(tile) + rows.append(row) + + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_extent) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_extent) + result_row.append(tile[:, :, :row_limit, :row_limit]) + result_rows.append(torch.cat(result_row, dim=-1)) + + moments = torch.cat(result_rows, dim=-2) + return moments + + def spatial_tiled_decode(self, z: torch.Tensor) -> torch.Tensor: + """ + Decode large latents using spatial tiling to reduce memory usage. + Tiles are processed independently and blended at boundaries. + + Parameters + ---------- + z : torch.Tensor + Latent tensor of shape (B, C, H, W). + """ + B, C, H, W = z.shape + overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor)) + blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor) + row_limit = self.tile_sample_min_size - blend_extent + + rows = [] + for i in range(0, H, overlap_size): + row = [] + for j in range(0, W, overlap_size): + tile = z[:, :, :, i : i + self.tile_latent_min_size, j : j + self.tile_latent_min_size] + decoded = self.decoder(tile) + row.append(decoded) + rows.append(row) + + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_extent) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_extent) + result_row.append(tile[:, :, :, :row_limit, :row_limit]) + result_rows.append(torch.cat(result_row, dim=-1)) + + dec = torch.cat(result_rows, dim=-2) + return dec + + def encode(self, x: Tensor) -> DiagonalGaussianDistribution: + """ + Encode input images to latent representation. + Uses spatial tiling for large images if enabled. + + Parameters + ---------- + x : Tensor + Input image tensor of shape (B, C, H, W) or (B, C, T, H, W). + + Returns + ------- + DiagonalGaussianDistribution + Latent distribution with mean and logvar. + """ + # Handle 5D input (B, C, T, H, W) by removing time dimension + original_ndim = x.ndim + if original_ndim == 5: + x = x.squeeze(2) + + # Use tiling for large images to reduce memory usage + if self.use_spatial_tiling and (x.shape[-1] > self.tile_sample_min_size or x.shape[-2] > self.tile_sample_min_size): + h = self.spatial_tiled_encode(x) + else: + h = self.encoder(x) + + # Restore time dimension if input was 5D + if original_ndim == 5: + h = h.unsqueeze(2) + + posterior = DiagonalGaussianDistribution(h) + return posterior + + def decode(self, z: Tensor): + """ + Decode latent representation back to images. + Uses spatial tiling for large latents if enabled. + + Parameters + ---------- + z : Tensor + Latent tensor of shape (B, C, H, W) or (B, C, T, H, W). + + Returns + ------- + Tensor + Decoded image tensor. + """ + # Handle 5D input (B, C, T, H, W) by removing time dimension + original_ndim = z.ndim + if original_ndim == 5: + z = z.squeeze(2) + + # Use tiling for large latents to reduce memory usage + if self.use_spatial_tiling and (z.shape[-1] > self.tile_latent_min_size or z.shape[-2] > self.tile_latent_min_size): + decoded = self.spatial_tiled_decode(z) + else: + decoded = self.decoder(z) + + # Restore time dimension if input was 5D + if original_ndim == 5: + decoded = decoded.unsqueeze(2) + + return decoded + + +def load_vae(vae_path: str, device: torch.device, disable_mmap: bool = False, chunk_size: Optional[int] = None) -> HunyuanVAE2D: + logger.info(f"Initializing VAE with chunk_size={chunk_size}") + vae = HunyuanVAE2D(chunk_size=chunk_size) + + logger.info(f"Loading VAE from {vae_path}") + state_dict = load_safetensors(vae_path, device=device, disable_mmap=disable_mmap) + info = vae.load_state_dict(state_dict, strict=True, assign=True) + logger.info(f"Loaded VAE: {info}") + + vae.to(device) + return vae diff --git a/library/lora_utils.py b/library/lora_utils.py new file mode 100644 index 000000000..6f0fc2285 --- /dev/null +++ b/library/lora_utils.py @@ -0,0 +1,246 @@ +import os +import re +from typing import Dict, List, Optional, Union +import torch +from tqdm import tqdm +from library.device_utils import synchronize_device +from library.fp8_optimization_utils import load_safetensors_with_fp8_optimization +from library.safetensors_utils import MemoryEfficientSafeOpen +from library.utils import setup_logging + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + + +def filter_lora_state_dict( + weights_sd: Dict[str, torch.Tensor], + include_pattern: Optional[str] = None, + exclude_pattern: Optional[str] = None, +) -> Dict[str, torch.Tensor]: + # apply include/exclude patterns + original_key_count = len(weights_sd.keys()) + if include_pattern is not None: + regex_include = re.compile(include_pattern) + weights_sd = {k: v for k, v in weights_sd.items() if regex_include.search(k)} + logger.info(f"Filtered keys with include pattern {include_pattern}: {original_key_count} -> {len(weights_sd.keys())}") + + if exclude_pattern is not None: + original_key_count_ex = len(weights_sd.keys()) + regex_exclude = re.compile(exclude_pattern) + weights_sd = {k: v for k, v in weights_sd.items() if not regex_exclude.search(k)} + logger.info(f"Filtered keys with exclude pattern {exclude_pattern}: {original_key_count_ex} -> {len(weights_sd.keys())}") + + if len(weights_sd) != original_key_count: + remaining_keys = list(set([k.split(".", 1)[0] for k in weights_sd.keys()])) + remaining_keys.sort() + logger.info(f"Remaining LoRA modules after filtering: {remaining_keys}") + if len(weights_sd) == 0: + logger.warning("No keys left after filtering.") + + return weights_sd + + +def load_safetensors_with_lora_and_fp8( + model_files: Union[str, List[str]], + lora_weights_list: Optional[Dict[str, torch.Tensor]], + lora_multipliers: Optional[List[float]], + fp8_optimization: bool, + calc_device: torch.device, + move_to_device: bool = False, + dit_weight_dtype: Optional[torch.dtype] = None, + target_keys: Optional[List[str]] = None, + exclude_keys: Optional[List[str]] = None, +) -> dict[str, torch.Tensor]: + """ + Merge LoRA weights into the state dict of a model with fp8 optimization if needed. + + Args: + model_files (Union[str, List[str]]): Path to the model file or list of paths. If the path matches a pattern like `00001-of-00004`, it will load all files with the same prefix. + lora_weights_list (Optional[Dict[str, torch.Tensor]]): Dictionary of LoRA weight tensors to load. + lora_multipliers (Optional[List[float]]): List of multipliers for LoRA weights. + fp8_optimization (bool): Whether to apply FP8 optimization. + calc_device (torch.device): Device to calculate on. + move_to_device (bool): Whether to move tensors to the calculation device after loading. + target_keys (Optional[List[str]]): Keys to target for optimization. + exclude_keys (Optional[List[str]]): Keys to exclude from optimization. + """ + + # if the file name ends with 00001-of-00004 etc, we need to load the files with the same prefix + if isinstance(model_files, str): + model_files = [model_files] + + extended_model_files = [] + for model_file in model_files: + basename = os.path.basename(model_file) + match = re.match(r"^(.*?)(\d+)-of-(\d+)\.safetensors$", basename) + if match: + prefix = basename[: match.start(2)] + count = int(match.group(3)) + state_dict = {} + for i in range(count): + filename = f"{prefix}{i + 1:05d}-of-{count:05d}.safetensors" + filepath = os.path.join(os.path.dirname(model_file), filename) + if os.path.exists(filepath): + extended_model_files.append(filepath) + else: + raise FileNotFoundError(f"File {filepath} not found") + else: + extended_model_files.append(model_file) + model_files = extended_model_files + logger.info(f"Loading model files: {model_files}") + + # load LoRA weights + weight_hook = None + if lora_weights_list is None or len(lora_weights_list) == 0: + lora_weights_list = [] + lora_multipliers = [] + list_of_lora_weight_keys = [] + else: + list_of_lora_weight_keys = [] + for lora_sd in lora_weights_list: + lora_weight_keys = set(lora_sd.keys()) + list_of_lora_weight_keys.append(lora_weight_keys) + + if lora_multipliers is None: + lora_multipliers = [1.0] * len(lora_weights_list) + while len(lora_multipliers) < len(lora_weights_list): + lora_multipliers.append(1.0) + if len(lora_multipliers) > len(lora_weights_list): + lora_multipliers = lora_multipliers[: len(lora_weights_list)] + + # Merge LoRA weights into the state dict + logger.info(f"Merging LoRA weights into state dict. multipliers: {lora_multipliers}") + + # make hook for LoRA merging + def weight_hook_func(model_weight_key, model_weight, keep_on_calc_device=False): + nonlocal list_of_lora_weight_keys, lora_weights_list, lora_multipliers, calc_device + + if not model_weight_key.endswith(".weight"): + return model_weight + + original_device = model_weight.device + if original_device != calc_device: + model_weight = model_weight.to(calc_device) # to make calculation faster + + for lora_weight_keys, lora_sd, multiplier in zip(list_of_lora_weight_keys, lora_weights_list, lora_multipliers): + # check if this weight has LoRA weights + lora_name = model_weight_key.rsplit(".", 1)[0] # remove trailing ".weight" + lora_name = "lora_unet_" + lora_name.replace(".", "_") + down_key = lora_name + ".lora_down.weight" + up_key = lora_name + ".lora_up.weight" + alpha_key = lora_name + ".alpha" + if down_key not in lora_weight_keys or up_key not in lora_weight_keys: + continue + + # get LoRA weights + down_weight = lora_sd[down_key] + up_weight = lora_sd[up_key] + + dim = down_weight.size()[0] + alpha = lora_sd.get(alpha_key, dim) + scale = alpha / dim + + down_weight = down_weight.to(calc_device) + up_weight = up_weight.to(calc_device) + + # W <- W + U * D + if len(model_weight.size()) == 2: + # linear + if len(up_weight.size()) == 4: # use linear projection mismatch + up_weight = up_weight.squeeze(3).squeeze(2) + down_weight = down_weight.squeeze(3).squeeze(2) + model_weight = model_weight + multiplier * (up_weight @ down_weight) * scale + elif down_weight.size()[2:4] == (1, 1): + # conv2d 1x1 + model_weight = ( + model_weight + + multiplier + * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3) + * scale + ) + else: + # conv2d 3x3 + conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3) + # logger.info(conved.size(), weight.size(), module.stride, module.padding) + model_weight = model_weight + multiplier * conved * scale + + # remove LoRA keys from set + lora_weight_keys.remove(down_key) + lora_weight_keys.remove(up_key) + if alpha_key in lora_weight_keys: + lora_weight_keys.remove(alpha_key) + + if not keep_on_calc_device and original_device != calc_device: + model_weight = model_weight.to(original_device) # move back to original device + return model_weight + + weight_hook = weight_hook_func + + state_dict = load_safetensors_with_fp8_optimization_and_hook( + model_files, + fp8_optimization, + calc_device, + move_to_device, + dit_weight_dtype, + target_keys, + exclude_keys, + weight_hook=weight_hook, + ) + + for lora_weight_keys in list_of_lora_weight_keys: + # check if all LoRA keys are used + if len(lora_weight_keys) > 0: + # if there are still LoRA keys left, it means they are not used in the model + # this is a warning, not an error + logger.warning(f"Warning: not all LoRA keys are used: {', '.join(lora_weight_keys)}") + + return state_dict + + +def load_safetensors_with_fp8_optimization_and_hook( + model_files: list[str], + fp8_optimization: bool, + calc_device: torch.device, + move_to_device: bool = False, + dit_weight_dtype: Optional[torch.dtype] = None, + target_keys: Optional[List[str]] = None, + exclude_keys: Optional[List[str]] = None, + weight_hook: callable = None, +) -> dict[str, torch.Tensor]: + """ + Load state dict from safetensors files and merge LoRA weights into the state dict with fp8 optimization if needed. + """ + if fp8_optimization: + logger.info( + f"Loading state dict with FP8 optimization. Dtype of weight: {dit_weight_dtype}, hook enabled: {weight_hook is not None}" + ) + # dit_weight_dtype is not used because we use fp8 optimization + state_dict = load_safetensors_with_fp8_optimization( + model_files, calc_device, target_keys, exclude_keys, move_to_device=move_to_device, weight_hook=weight_hook + ) + else: + logger.info( + f"Loading state dict without FP8 optimization. Dtype of weight: {dit_weight_dtype}, hook enabled: {weight_hook is not None}" + ) + state_dict = {} + for model_file in model_files: + with MemoryEfficientSafeOpen(model_file) as f: + for key in tqdm(f.keys(), desc=f"Loading {os.path.basename(model_file)}", leave=False): + if weight_hook is None and move_to_device: + value = f.get_tensor(key, device=calc_device, dtype=dit_weight_dtype) + else: + value = f.get_tensor(key) # we cannot directly load to device because get_tensor does non-blocking transfer + if weight_hook is not None: + value = weight_hook(key, value, keep_on_calc_device=move_to_device) + if move_to_device: + value = value.to(calc_device, dtype=dit_weight_dtype, non_blocking=True) + elif dit_weight_dtype is not None: + value = value.to(dit_weight_dtype) + + state_dict[key] = value + if move_to_device: + synchronize_device(calc_device) + + return state_dict diff --git a/library/lumina_models.py b/library/lumina_models.py new file mode 100644 index 000000000..7e9253525 --- /dev/null +++ b/library/lumina_models.py @@ -0,0 +1,1392 @@ +# Copyright Alpha VLLM/Lumina Image 2.0 and contributors +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# References: +# GLIDE: https://github.com/openai/glide-text2im +# MAE: https://github.com/facebookresearch/mae/blob/main/models_mae.py +# -------------------------------------------------------- + +import math +from typing import List, Optional, Tuple +from dataclasses import dataclass + +import torch +from torch import Tensor +from torch.utils.checkpoint import checkpoint +import torch.nn as nn +import torch.nn.functional as F + +from library import custom_offloading_utils + +try: + from flash_attn import flash_attn_varlen_func + from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa +except: + # flash_attn may not be available but it is not required + pass + +try: + from sageattention import sageattn +except: + pass + +try: + from apex.normalization import FusedRMSNorm as RMSNorm +except: + import warnings + + warnings.warn("Cannot import apex RMSNorm, switch to vanilla implementation") + + ############################################################################# + # RMSNorm # + ############################################################################# + + class RMSNorm(torch.nn.Module): + def __init__(self, dim: int, eps: float = 1e-6): + """ + Initialize the RMSNorm normalization layer. + + Args: + dim (int): The dimension of the input tensor. + eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6. + + Attributes: + eps (float): A small value added to the denominator for numerical stability. + weight (nn.Parameter): Learnable scaling parameter. + + """ + super().__init__() + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim)) + + def _norm(self, x) -> Tensor: + """ + Apply the RMSNorm normalization to the input tensor. + + Args: + x (torch.Tensor): The input tensor. + + Returns: + torch.Tensor: The normalized tensor. + + """ + return x * torch.rsqrt(x.float().pow(2).mean(-1, keepdim=True) + self.eps) + + def forward(self, x: Tensor): + """ + Apply RMSNorm to the input tensor. + + Args: + x (torch.Tensor): The input tensor. + + Returns: + torch.Tensor: The normalized tensor. + """ + x_dtype = x.dtype + # To handle float8 we need to convert the tensor to float + x = x.float() + rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6) + return ((x * rrms) * self.weight.float()).to(dtype=x_dtype) + + + +@dataclass +class LuminaParams: + """Parameters for Lumina model configuration""" + + patch_size: int = 2 + in_channels: int = 4 + dim: int = 4096 + n_layers: int = 30 + n_refiner_layers: int = 2 + n_heads: int = 24 + n_kv_heads: int = 8 + multiple_of: int = 256 + axes_dims: List[int] = None + axes_lens: List[int] = None + qk_norm: bool = False + ffn_dim_multiplier: Optional[float] = None + norm_eps: float = 1e-5 + scaling_factor: float = 1.0 + cap_feat_dim: int = 32 + + def __post_init__(self): + if self.axes_dims is None: + self.axes_dims = [36, 36, 36] + if self.axes_lens is None: + self.axes_lens = [300, 512, 512] + + @classmethod + def get_2b_config(cls) -> "LuminaParams": + """Returns the configuration for the 2B parameter model""" + return cls( + patch_size=2, + in_channels=16, # VAE channels + dim=2304, + n_layers=26, + n_heads=24, + n_kv_heads=8, + axes_dims=[32, 32, 32], + axes_lens=[300, 512, 512], + qk_norm=True, + cap_feat_dim=2304, # Gemma 2 hidden_size + ) + + @classmethod + def get_7b_config(cls) -> "LuminaParams": + """Returns the configuration for the 7B parameter model""" + return cls( + patch_size=2, + dim=4096, + n_layers=32, + n_heads=32, + n_kv_heads=8, + axes_dims=[64, 64, 64], + axes_lens=[300, 512, 512], + ) + + +class GradientCheckpointMixin(nn.Module): + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + + self.gradient_checkpointing = False + self.cpu_offload_checkpointing = False + + def enable_gradient_checkpointing(self, cpu_offload: bool = False): + self.gradient_checkpointing = True + + def disable_gradient_checkpointing(self, cpu_offload: bool = False): + self.gradient_checkpointing = False + + def forward(self, *args, **kwargs): + if self.training and self.gradient_checkpointing: + return checkpoint(self._forward, *args, use_reentrant=False, **kwargs) + else: + return self._forward(*args, **kwargs) + + + +def modulate(x, scale): + return x * (1 + scale.unsqueeze(1)) + + +############################################################################# +# Embedding Layers for Timesteps and Class Labels # +############################################################################# + + +class TimestepEmbedder(GradientCheckpointMixin): + """ + Embeds scalar timesteps into vector representations. + """ + + def __init__(self, hidden_size, frequency_embedding_size=256): + super().__init__() + self.mlp = nn.Sequential( + nn.Linear( + frequency_embedding_size, + hidden_size, + bias=True, + ), + nn.SiLU(), + nn.Linear( + hidden_size, + hidden_size, + bias=True, + ), + ) + nn.init.normal_(self.mlp[0].weight, std=0.02) + nn.init.zeros_(self.mlp[0].bias) + nn.init.normal_(self.mlp[2].weight, std=0.02) + nn.init.zeros_(self.mlp[2].bias) + + self.frequency_embedding_size = frequency_embedding_size + + @staticmethod + def timestep_embedding(t, dim, max_period=10000): + """ + Create sinusoidal timestep embeddings. + :param t: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an (N, D) Tensor of positional embeddings. + """ + # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py + half = dim // 2 + freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(device=t.device) + args = t[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + return embedding + + def _forward(self, t): + t_freq = self.timestep_embedding(t, self.frequency_embedding_size) + t_emb = self.mlp(t_freq.to(self.mlp[0].weight.dtype)) + return t_emb + + +def to_cuda(x): + if isinstance(x, torch.Tensor): + return x.cuda() + elif isinstance(x, (list, tuple)): + return [to_cuda(elem) for elem in x] + elif isinstance(x, dict): + return {k: to_cuda(v) for k, v in x.items()} + else: + return x + + +def to_cpu(x): + if isinstance(x, torch.Tensor): + return x.cpu() + elif isinstance(x, (list, tuple)): + return [to_cpu(elem) for elem in x] + elif isinstance(x, dict): + return {k: to_cpu(v) for k, v in x.items()} + else: + return x + + +############################################################################# +# Core NextDiT Model # +############################################################################# + + +class JointAttention(nn.Module): + """Multi-head attention module.""" + + def __init__( + self, + dim: int, + n_heads: int, + n_kv_heads: Optional[int], + qk_norm: bool, + use_flash_attn=False, + use_sage_attn=False, + ): + """ + Initialize the Attention module. + + Args: + dim (int): Number of input dimensions. + n_heads (int): Number of heads. + n_kv_heads (Optional[int]): Number of kv heads, if using GQA. + qk_norm (bool): Whether to use normalization for queries and keys. + + """ + super().__init__() + self.n_kv_heads = n_heads if n_kv_heads is None else n_kv_heads + self.n_local_heads = n_heads + self.n_local_kv_heads = self.n_kv_heads + self.n_rep = self.n_local_heads // self.n_local_kv_heads + self.head_dim = dim // n_heads + + self.qkv = nn.Linear( + dim, + (n_heads + self.n_kv_heads + self.n_kv_heads) * self.head_dim, + bias=False, + ) + nn.init.xavier_uniform_(self.qkv.weight) + + self.out = nn.Linear( + n_heads * self.head_dim, + dim, + bias=False, + ) + nn.init.xavier_uniform_(self.out.weight) + + if qk_norm: + self.q_norm = RMSNorm(self.head_dim) + self.k_norm = RMSNorm(self.head_dim) + else: + self.q_norm = self.k_norm = nn.Identity() + + self.use_flash_attn = use_flash_attn + self.use_sage_attn = use_sage_attn + + if use_sage_attn : + self.attention_processor = self.sage_attn + else: + # self.attention_processor = xformers.ops.memory_efficient_attention + self.attention_processor = F.scaled_dot_product_attention + + def set_attention_processor(self, attention_processor): + self.attention_processor = attention_processor + + def get_attention_processor(self): + return self.attention_processor + + def forward( + self, + x: Tensor, + x_mask: Tensor, + freqs_cis: Tensor, + ) -> Tensor: + """ + Args: + x: + x_mask: + freqs_cis: + """ + bsz, seqlen, _ = x.shape + dtype = x.dtype + + xq, xk, xv = torch.split( + self.qkv(x), + [ + self.n_local_heads * self.head_dim, + self.n_local_kv_heads * self.head_dim, + self.n_local_kv_heads * self.head_dim, + ], + dim=-1, + ) + xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim) + xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim) + xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim) + xq = self.q_norm(xq) + xk = self.k_norm(xk) + xq = apply_rope(xq, freqs_cis=freqs_cis) + xk = apply_rope(xk, freqs_cis=freqs_cis) + xq, xk = xq.to(dtype), xk.to(dtype) + + softmax_scale = math.sqrt(1 / self.head_dim) + + if self.use_sage_attn: + # Handle GQA (Grouped Query Attention) if needed + n_rep = self.n_local_heads // self.n_local_kv_heads + if n_rep >= 1: + xk = xk.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3) + xv = xv.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3) + + output = self.sage_attn(xq, xk, xv, x_mask, softmax_scale) + elif self.use_flash_attn: + output = self.flash_attn(xq, xk, xv, x_mask, softmax_scale) + else: + n_rep = self.n_local_heads // self.n_local_kv_heads + if n_rep >= 1: + xk = xk.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3) + xv = xv.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3) + + output = ( + self.attention_processor( + xq.permute(0, 2, 1, 3), + xk.permute(0, 2, 1, 3), + xv.permute(0, 2, 1, 3), + attn_mask=x_mask.bool().view(bsz, 1, 1, seqlen).expand(-1, self.n_local_heads, seqlen, -1), + scale=softmax_scale, + ) + .permute(0, 2, 1, 3) + .to(dtype) + ) + + output = output.flatten(-2) + return self.out(output) + + # copied from huggingface modeling_llama.py + def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length): + def _get_unpad_data(attention_mask): + seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) + return ( + indices, + cu_seqlens, + max_seqlen_in_batch, + ) + + indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) + batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape + + key_layer = index_first_axis( + key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), + indices_k, + ) + value_layer = index_first_axis( + value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), + indices_k, + ) + if query_length == kv_seq_len: + query_layer = index_first_axis( + query_layer.reshape(batch_size * kv_seq_len, self.n_local_heads, head_dim), + indices_k, + ) + cu_seqlens_q = cu_seqlens_k + max_seqlen_in_batch_q = max_seqlen_in_batch_k + indices_q = indices_k + elif query_length == 1: + max_seqlen_in_batch_q = 1 + cu_seqlens_q = torch.arange( + batch_size + 1, dtype=torch.int32, device=query_layer.device + ) # There is a memcpy here, that is very bad. + indices_q = cu_seqlens_q[:-1] + query_layer = query_layer.squeeze(1) + else: + # The -q_len: slice assumes left padding. + attention_mask = attention_mask[:, -query_length:] + query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask) + + return ( + query_layer, + key_layer, + value_layer, + indices_q, + (cu_seqlens_q, cu_seqlens_k), + (max_seqlen_in_batch_q, max_seqlen_in_batch_k), + ) + + def sage_attn(self, q: Tensor, k: Tensor, v: Tensor, x_mask: Tensor, softmax_scale: float): + try: + bsz = q.shape[0] + seqlen = q.shape[1] + + # Transpose tensors to match SageAttention's expected format (HND layout) + q_transposed = q.permute(0, 2, 1, 3) # [batch, heads, seq_len, head_dim] + k_transposed = k.permute(0, 2, 1, 3) # [batch, heads, seq_len, head_dim] + v_transposed = v.permute(0, 2, 1, 3) # [batch, heads, seq_len, head_dim] + + # Handle masking for SageAttention + # We need to filter out masked positions - this approach handles variable sequence lengths + outputs = [] + for b in range(bsz): + # Find valid token positions from the mask + valid_indices = torch.nonzero(x_mask[b], as_tuple=False).squeeze(-1) + if valid_indices.numel() == 0: + # If all tokens are masked, create a zero output + batch_output = torch.zeros( + seqlen, self.n_local_heads, self.head_dim, + device=q.device, dtype=q.dtype + ) + else: + # Extract only valid tokens for this batch + batch_q = q_transposed[b, :, valid_indices, :] + batch_k = k_transposed[b, :, valid_indices, :] + batch_v = v_transposed[b, :, valid_indices, :] + + # Run SageAttention on valid tokens only + batch_output_valid = sageattn( + batch_q.unsqueeze(0), # Add batch dimension back + batch_k.unsqueeze(0), + batch_v.unsqueeze(0), + tensor_layout="HND", + is_causal=False, + sm_scale=softmax_scale + ) + + # Create output tensor with zeros for masked positions + batch_output = torch.zeros( + seqlen, self.n_local_heads, self.head_dim, + device=q.device, dtype=q.dtype + ) + # Place valid outputs back in the right positions + batch_output[valid_indices] = batch_output_valid.squeeze(0).permute(1, 0, 2) + + outputs.append(batch_output) + + # Stack batch outputs and reshape to expected format + output = torch.stack(outputs, dim=0) # [batch, seq_len, heads, head_dim] + except NameError as e: + raise RuntimeError( + f"Could not load Sage Attention. Please install https://github.com/thu-ml/SageAttention. / Sage Attention を読み込めませんでした。https://github.com/thu-ml/SageAttention をインストールしてください。 / {e}" + ) + + return output + + def flash_attn( + self, + q: Tensor, + k: Tensor, + v: Tensor, + x_mask: Tensor, + softmax_scale, + ) -> Tensor: + bsz, seqlen, _, _ = q.shape + + try: + # begin var_len flash attn + ( + query_states, + key_states, + value_states, + indices_q, + cu_seq_lens, + max_seq_lens, + ) = self._upad_input(q, k, v, x_mask, seqlen) + + cu_seqlens_q, cu_seqlens_k = cu_seq_lens + max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens + + attn_output_unpad = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_in_batch_q, + max_seqlen_k=max_seqlen_in_batch_k, + dropout_p=0.0, + causal=False, + softmax_scale=softmax_scale, + ) + output = pad_input(attn_output_unpad, indices_q, bsz, seqlen) + # end var_len_flash_attn + + return output + except NameError as e: + raise RuntimeError( + f"Could not load flash attention. Please install flash_attn. / フラッシュアテンションを読み込めませんでした。flash_attn をインストールしてください。 / {e}" + ) + + +def apply_rope( + x_in: torch.Tensor, + freqs_cis: torch.Tensor, +) -> torch.Tensor: + """ + Apply rotary embeddings to input tensors using the given frequency + tensor. + + This function applies rotary embeddings to the given query 'xq' and + key 'xk' tensors using the provided frequency tensor 'freqs_cis'. The + input tensors are reshaped as complex numbers, and the frequency tensor + is reshaped for broadcasting compatibility. The resulting tensors + contain rotary embeddings and are returned as real tensors. + + Args: + x_in (torch.Tensor): Query or Key tensor to apply rotary embeddings. + freqs_cis (torch.Tensor): Precomputed frequency tensor for complex + exponentials. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor + and key tensor with rotary embeddings. + """ + with torch.autocast("cuda", enabled=False): + x = torch.view_as_complex(x_in.float().reshape(*x_in.shape[:-1], -1, 2)) + freqs_cis = freqs_cis.unsqueeze(2) + x_out = torch.view_as_real(x * freqs_cis).flatten(3) + + return x_out.type_as(x_in) + + +class FeedForward(nn.Module): + def __init__( + self, + dim: int, + hidden_dim: int, + multiple_of: int, + ffn_dim_multiplier: Optional[float], + ): + """ + Initialize the FeedForward module. + + Args: + dim (int): Input dimension. + hidden_dim (int): Hidden dimension of the feedforward layer. + multiple_of (int): Value to ensure hidden dimension is a multiple + of this value. + ffn_dim_multiplier (float, optional): Custom multiplier for hidden + dimension. Defaults to None. + + """ + super().__init__() + # custom dim factor multiplier + if ffn_dim_multiplier is not None: + hidden_dim = int(ffn_dim_multiplier * hidden_dim) + hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) + + self.w1 = nn.Linear( + dim, + hidden_dim, + bias=False, + ) + nn.init.xavier_uniform_(self.w1.weight) + self.w2 = nn.Linear( + hidden_dim, + dim, + bias=False, + ) + nn.init.xavier_uniform_(self.w2.weight) + self.w3 = nn.Linear( + dim, + hidden_dim, + bias=False, + ) + nn.init.xavier_uniform_(self.w3.weight) + + # @torch.compile + def _forward_silu_gating(self, x1, x3): + return F.silu(x1) * x3 + + def forward(self, x): + return self.w2(self._forward_silu_gating(self.w1(x), self.w3(x))) + + +class JointTransformerBlock(GradientCheckpointMixin): + def __init__( + self, + layer_id: int, + dim: int, + n_heads: int, + n_kv_heads: Optional[int], + multiple_of: int, + ffn_dim_multiplier: Optional[float], + norm_eps: float, + qk_norm: bool, + modulation=True, + use_flash_attn=False, + use_sage_attn=False, + ) -> None: + """ + Initialize a TransformerBlock. + + Args: + layer_id (int): Identifier for the layer. + dim (int): Embedding dimension of the input features. + n_heads (int): Number of attention heads. + n_kv_heads (Optional[int]): Number of attention heads in key and + value features (if using GQA), or set to None for the same as + query. + multiple_of (int): Number of multiple of the hidden dimension. + ffn_dim_multiplier (Optional[float]): Dimension multiplier for the + feedforward layer. + norm_eps (float): Epsilon value for normalization. + qk_norm (bool): Whether to use normalization for queries and keys. + modulation (bool): Whether to use modulation for the attention + layer. + """ + super().__init__() + self.dim = dim + self.head_dim = dim // n_heads + self.attention = JointAttention(dim, n_heads, n_kv_heads, qk_norm, use_flash_attn=use_flash_attn, use_sage_attn=use_sage_attn) + self.feed_forward = FeedForward( + dim=dim, + hidden_dim=4 * dim, + multiple_of=multiple_of, + ffn_dim_multiplier=ffn_dim_multiplier, + ) + self.layer_id = layer_id + self.attention_norm1 = RMSNorm(dim, eps=norm_eps) + self.ffn_norm1 = RMSNorm(dim, eps=norm_eps) + + self.attention_norm2 = RMSNorm(dim, eps=norm_eps) + self.ffn_norm2 = RMSNorm(dim, eps=norm_eps) + + self.modulation = modulation + if modulation: + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + nn.Linear( + min(dim, 1024), + 4 * dim, + bias=True, + ), + ) + nn.init.zeros_(self.adaLN_modulation[1].weight) + nn.init.zeros_(self.adaLN_modulation[1].bias) + + def _forward( + self, + x: torch.Tensor, + x_mask: torch.Tensor, + pe: torch.Tensor, + adaln_input: Optional[torch.Tensor] = None, + ): + """ + Perform a forward pass through the TransformerBlock. + + Args: + x (Tensor): Input tensor. + pe (Tensor): Rope position embedding. + + Returns: + Tensor: Output tensor after applying attention and + feedforward layers. + + """ + if self.modulation: + assert adaln_input is not None + scale_msa, gate_msa, scale_mlp, gate_mlp = self.adaLN_modulation(adaln_input).chunk(4, dim=1) + + x = x + gate_msa.unsqueeze(1).tanh() * self.attention_norm2( + self.attention( + modulate(self.attention_norm1(x), scale_msa), + x_mask, + pe, + ) + ) + x = x + gate_mlp.unsqueeze(1).tanh() * self.ffn_norm2( + self.feed_forward( + modulate(self.ffn_norm1(x), scale_mlp), + ) + ) + else: + assert adaln_input is None + x = x + self.attention_norm2( + self.attention( + self.attention_norm1(x), + x_mask, + pe, + ) + ) + x = x + self.ffn_norm2( + self.feed_forward( + self.ffn_norm1(x), + ) + ) + return x + + +class FinalLayer(GradientCheckpointMixin): + """ + The final layer of NextDiT. + """ + + def __init__(self, hidden_size, patch_size, out_channels): + """ + Initialize the FinalLayer. + + Args: + hidden_size (int): Hidden size of the input features. + patch_size (int): Patch size of the input features. + out_channels (int): Number of output channels. + """ + super().__init__() + self.norm_final = nn.LayerNorm( + hidden_size, + elementwise_affine=False, + eps=1e-6, + ) + self.linear = nn.Linear( + hidden_size, + patch_size * patch_size * out_channels, + bias=True, + ) + nn.init.zeros_(self.linear.weight) + nn.init.zeros_(self.linear.bias) + + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + nn.Linear( + min(hidden_size, 1024), + hidden_size, + bias=True, + ), + ) + nn.init.zeros_(self.adaLN_modulation[1].weight) + nn.init.zeros_(self.adaLN_modulation[1].bias) + + def forward(self, x, c): + scale = self.adaLN_modulation(c) + x = modulate(self.norm_final(x), scale) + x = self.linear(x) + return x + + +class RopeEmbedder: + def __init__( + self, + theta: float = 10000.0, + axes_dims: List[int] = [16, 56, 56], + axes_lens: List[int] = [1, 512, 512], + ): + super().__init__() + self.theta = theta + self.axes_dims = axes_dims + self.axes_lens = axes_lens + self.freqs_cis = NextDiT.precompute_freqs_cis(self.axes_dims, self.axes_lens, theta=self.theta) + + def __call__(self, ids: torch.Tensor): + device = ids.device + self.freqs_cis = [freqs_cis.to(ids.device) for freqs_cis in self.freqs_cis] + result = [] + for i in range(len(self.axes_dims)): + freqs = self.freqs_cis[i].to(ids.device) + index = ids[:, :, i : i + 1].repeat(1, 1, freqs.shape[-1]).to(torch.int64) + result.append(torch.gather(freqs.unsqueeze(0).repeat(index.shape[0], 1, 1), dim=1, index=index)) + return torch.cat(result, dim=-1) + + +class NextDiT(nn.Module): + """ + Diffusion model with a Transformer backbone. + """ + + def __init__( + self, + patch_size: int = 2, + in_channels: int = 4, + dim: int = 4096, + n_layers: int = 32, + n_refiner_layers: int = 2, + n_heads: int = 32, + n_kv_heads: Optional[int] = None, + multiple_of: int = 256, + ffn_dim_multiplier: Optional[float] = None, + norm_eps: float = 1e-5, + qk_norm: bool = False, + cap_feat_dim: int = 5120, + axes_dims: List[int] = [16, 56, 56], + axes_lens: List[int] = [1, 512, 512], + use_flash_attn=False, + use_sage_attn=False, + ) -> None: + """ + Initialize the NextDiT model. + + Args: + patch_size (int): Patch size of the input features. + in_channels (int): Number of input channels. + dim (int): Hidden size of the input features. + n_layers (int): Number of Transformer layers. + n_refiner_layers (int): Number of refiner layers. + n_heads (int): Number of attention heads. + n_kv_heads (Optional[int]): Number of attention heads in key and + value features (if using GQA), or set to None for the same as + query. + multiple_of (int): Multiple of the hidden size. + ffn_dim_multiplier (Optional[float]): Dimension multiplier for the + feedforward layer. + norm_eps (float): Epsilon value for normalization. + qk_norm (bool): Whether to use query key normalization. + cap_feat_dim (int): Dimension of the caption features. + axes_dims (List[int]): List of dimensions for the axes. + axes_lens (List[int]): List of lengths for the axes. + use_flash_attn (bool): Whether to use Flash Attention. + use_sage_attn (bool): Whether to use Sage Attention. Sage Attention only supports inference. + + Returns: + None + """ + super().__init__() + self.in_channels = in_channels + self.out_channels = in_channels + self.patch_size = patch_size + + self.t_embedder = TimestepEmbedder(min(dim, 1024)) + self.cap_embedder = nn.Sequential( + RMSNorm(cap_feat_dim, eps=norm_eps), + nn.Linear( + cap_feat_dim, + dim, + bias=True, + ), + ) + + nn.init.trunc_normal_(self.cap_embedder[1].weight, std=0.02) + nn.init.zeros_(self.cap_embedder[1].bias) + + self.context_refiner = nn.ModuleList( + [ + JointTransformerBlock( + layer_id, + dim, + n_heads, + n_kv_heads, + multiple_of, + ffn_dim_multiplier, + norm_eps, + qk_norm, + modulation=False, + ) + for layer_id in range(n_refiner_layers) + ] + ) + + self.x_embedder = nn.Linear( + in_features=patch_size * patch_size * in_channels, + out_features=dim, + bias=True, + ) + nn.init.xavier_uniform_(self.x_embedder.weight) + nn.init.constant_(self.x_embedder.bias, 0.0) + + self.noise_refiner = nn.ModuleList( + [ + JointTransformerBlock( + layer_id, + dim, + n_heads, + n_kv_heads, + multiple_of, + ffn_dim_multiplier, + norm_eps, + qk_norm, + modulation=True, + ) + for layer_id in range(n_refiner_layers) + ] + ) + + + self.layers = nn.ModuleList( + [ + JointTransformerBlock( + layer_id, + dim, + n_heads, + n_kv_heads, + multiple_of, + ffn_dim_multiplier, + norm_eps, + qk_norm, + use_flash_attn=use_flash_attn, + use_sage_attn=use_sage_attn, + ) + for layer_id in range(n_layers) + ] + ) + self.norm_final = RMSNorm(dim, eps=norm_eps) + self.final_layer = FinalLayer(dim, patch_size, self.out_channels) + + assert (dim // n_heads) == sum(axes_dims) + self.axes_dims = axes_dims + self.axes_lens = axes_lens + self.rope_embedder = RopeEmbedder(axes_dims=axes_dims, axes_lens=axes_lens) + self.dim = dim + self.n_heads = n_heads + + self.gradient_checkpointing = False + self.cpu_offload_checkpointing = False # TODO: not yet supported + self.blocks_to_swap = None # TODO: not yet supported + + @property + def device(self): + return next(self.parameters()).device + + @property + def dtype(self): + return next(self.parameters()).dtype + + def enable_gradient_checkpointing(self, cpu_offload: bool = False): + self.gradient_checkpointing = True + self.cpu_offload_checkpointing = cpu_offload + + self.t_embedder.enable_gradient_checkpointing() + + for block in self.layers + self.context_refiner + self.noise_refiner: + block.enable_gradient_checkpointing(cpu_offload=cpu_offload) + + self.final_layer.enable_gradient_checkpointing() + + print(f"Lumina: Gradient checkpointing enabled. CPU offload: {cpu_offload}") + + def disable_gradient_checkpointing(self): + self.gradient_checkpointing = False + self.cpu_offload_checkpointing = False + + self.t_embedder.disable_gradient_checkpointing() + + for block in self.layers + self.context_refiner + self.noise_refiner: + block.disable_gradient_checkpointing() + + self.final_layer.disable_gradient_checkpointing() + + print("Lumina: Gradient checkpointing disabled.") + + def unpatchify( + self, + x: Tensor, + width: int, + height: int, + encoder_seq_lengths: List[int], + seq_lengths: List[int], + ) -> Tensor: + """ + Unpatchify the input tensor and embed the caption features. + x: (N, T, patch_size**2 * C) + imgs: (N, H, W, C) + + Args: + x (Tensor): Input tensor. + width (int): Width of the input tensor. + height (int): Height of the input tensor. + encoder_seq_lengths (List[int]): List of encoder sequence lengths. + seq_lengths (List[int]): List of sequence lengths + + Returns: + output: (N, C, H, W) + """ + pH = pW = self.patch_size + + output = [] + for i, (encoder_seq_len, seq_len) in enumerate(zip(encoder_seq_lengths, seq_lengths)): + output.append( + x[i][encoder_seq_len:seq_len] + .view(height // pH, width // pW, pH, pW, self.out_channels) + .permute(4, 0, 2, 1, 3) + .flatten(3, 4) + .flatten(1, 2) + ) + output = torch.stack(output, dim=0) + + return output + + def patchify_and_embed( + self, + x: Tensor, + cap_feats: Tensor, + cap_mask: Tensor, + t: Tensor, + ) -> Tuple[Tensor, Tensor, Tensor, List[int], List[int]]: + """ + Patchify and embed the input image and caption features. + + Args: + x: (N, C, H, W) image latents + cap_feats: (N, C, D) caption features + cap_mask: (N, C, D) caption attention mask + t: (N), T timesteps + + Returns: + Tuple[Tensor, Tensor, Tensor, List[int], List[int]]: + + return x, attention_mask, freqs_cis, l_effective_cap_len, seq_lengths + """ + bsz, channels, height, width = x.shape + pH = pW = self.patch_size + device = x.device + + l_effective_cap_len = cap_mask.sum(dim=1).tolist() + encoder_seq_len = cap_mask.shape[1] + image_seq_len = (height // self.patch_size) * (width // self.patch_size) + + seq_lengths = [cap_seq_len + image_seq_len for cap_seq_len in l_effective_cap_len] + max_seq_len = max(seq_lengths) + + position_ids = torch.zeros(bsz, max_seq_len, 3, dtype=torch.int32, device=device) + + for i, (cap_len, seq_len) in enumerate(zip(l_effective_cap_len, seq_lengths)): + H_tokens, W_tokens = height // pH, width // pW + + position_ids[i, :cap_len, 0] = torch.arange(cap_len, dtype=torch.int32, device=device) + position_ids[i, cap_len:seq_len, 0] = cap_len + + row_ids = torch.arange(H_tokens, dtype=torch.int32, device=device).view(-1, 1).repeat(1, W_tokens).flatten() + col_ids = torch.arange(W_tokens, dtype=torch.int32, device=device).view(1, -1).repeat(H_tokens, 1).flatten() + + position_ids[i, cap_len:seq_len, 1] = row_ids + position_ids[i, cap_len:seq_len, 2] = col_ids + + # Get combined rotary embeddings + freqs_cis = self.rope_embedder(position_ids) + + # Create separate rotary embeddings for captions and images + cap_freqs_cis = torch.zeros( + bsz, + encoder_seq_len, + freqs_cis.shape[-1], + device=device, + dtype=freqs_cis.dtype, + ) + img_freqs_cis = torch.zeros( + bsz, + image_seq_len, + freqs_cis.shape[-1], + device=device, + dtype=freqs_cis.dtype, + ) + + for i, (cap_len, seq_len) in enumerate(zip(l_effective_cap_len, seq_lengths)): + cap_freqs_cis[i, :cap_len] = freqs_cis[i, :cap_len] + img_freqs_cis[i, :image_seq_len] = freqs_cis[i, cap_len:seq_len] + + # Refine caption context + for layer in self.context_refiner: + cap_feats = layer(cap_feats, cap_mask, cap_freqs_cis) + + x = x.view(bsz, channels, height // pH, pH, width // pW, pW).permute(0, 2, 4, 3, 5, 1).flatten(3).flatten(1, 2) + + x_mask = torch.zeros(bsz, image_seq_len, dtype=torch.bool, device=device) + for i in range(bsz): + x[i, :image_seq_len] = x[i] + x_mask[i, :image_seq_len] = True + + x = self.x_embedder(x) + + # Refine image context + for layer in self.noise_refiner: + x = layer(x, x_mask, img_freqs_cis, t) + + joint_hidden_states = torch.zeros(bsz, max_seq_len, self.dim, device=device, dtype=x.dtype) + attention_mask = torch.zeros(bsz, max_seq_len, dtype=torch.bool, device=device) + for i, (cap_len, seq_len) in enumerate(zip(l_effective_cap_len, seq_lengths)): + attention_mask[i, :seq_len] = True + joint_hidden_states[i, :cap_len] = cap_feats[i, :cap_len] + joint_hidden_states[i, cap_len:seq_len] = x[i] + + x = joint_hidden_states + + return x, attention_mask, freqs_cis, l_effective_cap_len, seq_lengths + + def forward(self, x: Tensor, t: Tensor, cap_feats: Tensor, cap_mask: Tensor) -> Tensor: + """ + Forward pass of NextDiT. + Args: + x: (N, C, H, W) image latents + t: (N,) tensor of diffusion timesteps + cap_feats: (N, L, D) caption features + cap_mask: (N, L) caption attention mask + + Returns: + x: (N, C, H, W) denoised latents + """ + _, _, height, width = x.shape # B, C, H, W + t = self.t_embedder(t) # (N, D) + cap_feats = self.cap_embedder(cap_feats) # (N, L, D) # todo check if able to batchify w.o. redundant compute + + x, mask, freqs_cis, l_effective_cap_len, seq_lengths = self.patchify_and_embed(x, cap_feats, cap_mask, t) + + if not self.blocks_to_swap: + for layer in self.layers: + x = layer(x, mask, freqs_cis, t) + else: + for block_idx, layer in enumerate(self.layers): + self.offloader_main.wait_for_block(block_idx) + + x = layer(x, mask, freqs_cis, t) + + self.offloader_main.submit_move_blocks(self.layers, block_idx) + + x = self.final_layer(x, t) + x = self.unpatchify(x, width, height, l_effective_cap_len, seq_lengths) + + return x + + def forward_with_cfg( + self, + x: Tensor, + t: Tensor, + cap_feats: Tensor, + cap_mask: Tensor, + cfg_scale: float, + cfg_trunc: float = 0.25, + renorm_cfg: float = 1.0, + ): + """ + Forward pass of NextDiT, but also batches the unconditional forward pass + for classifier-free guidance. + """ + # # https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb + half = x[: len(x) // 2] + if t[0] < cfg_trunc: + combined = torch.cat([half, half], dim=0) # [2, 16, 128, 128] + assert ( + cap_mask.shape[0] == combined.shape[0] + ), f"caption attention mask shape: {cap_mask.shape[0]} latents shape: {combined.shape[0]}" + model_out = self.forward(x, t, cap_feats, cap_mask) # [2, 16, 128, 128] + # For exact reproducibility reasons, we apply classifier-free guidance on only + # three channels by default. The standard approach to cfg applies it to all channels. + # This can be done by uncommenting the following line and commenting-out the line following that. + eps, rest = ( + model_out[:, : self.in_channels], + model_out[:, self.in_channels :], + ) + cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0) + half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps) + if float(renorm_cfg) > 0.0: + ori_pos_norm = torch.linalg.vector_norm(cond_eps, dim=tuple(range(1, len(cond_eps.shape))), keepdim=True) + max_new_norm = ori_pos_norm * float(renorm_cfg) + new_pos_norm = torch.linalg.vector_norm(half_eps, dim=tuple(range(1, len(half_eps.shape))), keepdim=True) + if new_pos_norm >= max_new_norm: + half_eps = half_eps * (max_new_norm / new_pos_norm) + else: + combined = half + model_out = self.forward( + combined, + t[: len(x) // 2], + cap_feats[: len(x) // 2], + cap_mask[: len(x) // 2], + ) + eps, rest = ( + model_out[:, : self.in_channels], + model_out[:, self.in_channels :], + ) + half_eps = eps + + output = torch.cat([half_eps, half_eps], dim=0) + return output + + @staticmethod + def precompute_freqs_cis( + dim: List[int], + end: List[int], + theta: float = 10000.0, + ) -> List[Tensor]: + """ + Precompute the frequency tensor for complex exponentials (cis) with + given dimensions. + + This function calculates a frequency tensor with complex exponentials + using the given dimension 'dim' and the end index 'end'. The 'theta' + parameter scales the frequencies. The returned tensor contains complex + values in complex64 data type. + + Args: + dim (list): Dimension of the frequency tensor. + end (list): End index for precomputing frequencies. + theta (float, optional): Scaling factor for frequency computation. + Defaults to 10000.0. + + Returns: + List[torch.Tensor]: Precomputed frequency tensor with complex + exponentials. + """ + freqs_cis = [] + freqs_dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64 + + for i, (d, e) in enumerate(zip(dim, end)): + pos = torch.arange(e, dtype=freqs_dtype, device="cpu") + freqs = 1.0 / (theta ** (torch.arange(0, d, 2, dtype=freqs_dtype, device="cpu") / d)) + freqs = torch.outer(pos, freqs) + freqs_cis_i = torch.polar(torch.ones_like(freqs), freqs) # [S, D/2] + freqs_cis.append(freqs_cis_i) + + return freqs_cis + + def parameter_count(self) -> int: + total_params = 0 + + def _recursive_count_params(module): + nonlocal total_params + for param in module.parameters(recurse=False): + total_params += param.numel() + for submodule in module.children(): + _recursive_count_params(submodule) + + _recursive_count_params(self) + return total_params + + def get_fsdp_wrap_module_list(self) -> List[nn.Module]: + return list(self.layers) + + def get_checkpointing_wrap_module_list(self) -> List[nn.Module]: + return list(self.layers) + + def enable_block_swap(self, blocks_to_swap: int, device: torch.device): + """ + Enable block swapping to reduce memory usage during inference. + + Args: + num_blocks (int): Number of blocks to swap between CPU and device + device (torch.device): Device to use for computation + """ + self.blocks_to_swap = blocks_to_swap + + # Calculate how many blocks to swap from main layers + + assert blocks_to_swap <= len(self.layers) - 2, ( + f"Cannot swap more than {len(self.layers) - 2} main blocks. " + f"Requested {blocks_to_swap} blocks." + ) + + self.offloader_main = custom_offloading_utils.ModelOffloader( + self.layers, blocks_to_swap, device, debug=False + ) + + def move_to_device_except_swap_blocks(self, device: torch.device): + """ + Move the model to the device except for blocks that will be swapped. + This reduces temporary memory usage during model loading. + + Args: + device (torch.device): Device to move the model to + """ + if self.blocks_to_swap: + save_layers = self.layers + self.layers = nn.ModuleList([]) + + self.to(device) + + if self.blocks_to_swap: + self.layers = save_layers + + def prepare_block_swap_before_forward(self): + """ + Prepare blocks for swapping before forward pass. + """ + if self.blocks_to_swap is None or self.blocks_to_swap == 0: + return + + self.offloader_main.prepare_block_devices_before_forward(self.layers) + + +############################################################################# +# NextDiT Configs # +############################################################################# + + +def NextDiT_2B_GQA_patch2_Adaln_Refiner(params: Optional[LuminaParams] = None, **kwargs): + if params is None: + params = LuminaParams.get_2b_config() + + return NextDiT( + patch_size=params.patch_size, + in_channels=params.in_channels, + dim=params.dim, + n_layers=params.n_layers, + n_heads=params.n_heads, + n_kv_heads=params.n_kv_heads, + axes_dims=params.axes_dims, + axes_lens=params.axes_lens, + qk_norm=params.qk_norm, + ffn_dim_multiplier=params.ffn_dim_multiplier, + norm_eps=params.norm_eps, + cap_feat_dim=params.cap_feat_dim, + **kwargs, + ) + + +def NextDiT_3B_GQA_patch2_Adaln_Refiner(**kwargs): + return NextDiT( + patch_size=2, + dim=2592, + n_layers=30, + n_heads=24, + n_kv_heads=8, + axes_dims=[36, 36, 36], + axes_lens=[300, 512, 512], + **kwargs, + ) + + +def NextDiT_4B_GQA_patch2_Adaln_Refiner(**kwargs): + return NextDiT( + patch_size=2, + dim=2880, + n_layers=32, + n_heads=24, + n_kv_heads=8, + axes_dims=[40, 40, 40], + axes_lens=[300, 512, 512], + **kwargs, + ) + + +def NextDiT_7B_GQA_patch2_Adaln_Refiner(**kwargs): + return NextDiT( + patch_size=2, + dim=3840, + n_layers=32, + n_heads=32, + n_kv_heads=8, + axes_dims=[40, 40, 40], + axes_lens=[300, 512, 512], + **kwargs, + ) diff --git a/library/lumina_train_util.py b/library/lumina_train_util.py new file mode 100644 index 000000000..d5d5db05f --- /dev/null +++ b/library/lumina_train_util.py @@ -0,0 +1,1099 @@ +import inspect +import argparse +import math +import os +import numpy as np +import time +from typing import Callable, Dict, List, Optional, Tuple, Any, Union, Generator + +import torch +from torch import Tensor +from accelerate import Accelerator, PartialState +from transformers import Gemma2Model +from tqdm import tqdm +from PIL import Image +from safetensors.torch import save_file + +from library import lumina_models, strategy_base, strategy_lumina, train_util +from library.flux_models import AutoEncoder +from library.device_utils import init_ipex, clean_memory_on_device +from library.sd3_train_utils import FlowMatchEulerDiscreteScheduler +from library.safetensors_utils import mem_eff_save_file + +init_ipex() + +from .utils import setup_logging + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + + +# region sample images + + +def batchify( + prompt_dicts, batch_size=None +) -> Generator[list[dict[str, str]], None, None]: + """ + Group prompt dictionaries into batches with configurable batch size. + + Args: + prompt_dicts (list): List of dictionaries containing prompt parameters. + batch_size (int, optional): Number of prompts per batch. Defaults to None. + + Yields: + list[dict[str, str]]: Batch of prompts. + """ + # Validate batch_size + if batch_size is not None: + if not isinstance(batch_size, int) or batch_size <= 0: + raise ValueError("batch_size must be a positive integer or None") + + # Group prompts by their parameters + batches = {} + for prompt_dict in prompt_dicts: + # Extract parameters + width = int(prompt_dict.get("width", 1024)) + height = int(prompt_dict.get("height", 1024)) + height = max(64, height - height % 8) # round to divisible by 8 + width = max(64, width - width % 8) # round to divisible by 8 + guidance_scale = float(prompt_dict.get("scale", 3.5)) + sample_steps = int(prompt_dict.get("sample_steps", 38)) + cfg_trunc_ratio = float(prompt_dict.get("cfg_trunc_ratio", 0.25)) + renorm_cfg = float(prompt_dict.get("renorm_cfg", 1.0)) + seed = prompt_dict.get("seed", None) + seed = int(seed) if seed is not None else None + + # Create a key based on the parameters + key = ( + width, + height, + guidance_scale, + seed, + sample_steps, + cfg_trunc_ratio, + renorm_cfg, + ) + + # Add the prompt_dict to the corresponding batch + if key not in batches: + batches[key] = [] + batches[key].append(prompt_dict) + + # Yield each batch with its parameters + for key in batches: + prompts = batches[key] + if batch_size is None: + # Yield the entire group as a single batch + yield prompts + else: + # Split the group into batches of size `batch_size` + start = 0 + while start < len(prompts): + end = start + batch_size + batch = prompts[start:end] + yield batch + start = end + + +@torch.no_grad() +def sample_images( + accelerator: Accelerator, + args: argparse.Namespace, + epoch: int, + global_step: int, + nextdit: lumina_models.NextDiT, + vae: AutoEncoder, + gemma2_model: Gemma2Model, + sample_prompts_gemma2_outputs: dict[str, Tuple[Tensor, Tensor, Tensor]], + prompt_replacement: Optional[Tuple[str, str]] = None, + controlnet=None, +): + """ + Generate sample images using the NextDiT model. + + Args: + accelerator (Accelerator): Accelerator instance. + args (argparse.Namespace): Command-line arguments. + epoch (int): Current epoch number. + global_step (int): Current global step number. + nextdit (lumina_models.NextDiT): The NextDiT model instance. + vae (AutoEncoder): The VAE module. + gemma2_model (Gemma2Model): The Gemma2 model instance. + sample_prompts_gemma2_outputs (dict[str, Tuple[Tensor, Tensor, Tensor]]): + Dictionary of tuples containing the encoded prompts, text masks, and timestep for each sample. + prompt_replacement (Optional[Tuple[str, str]], optional): + Tuple containing the prompt and negative prompt replacements. Defaults to None. + controlnet (): ControlNet model, not yet supported + + Returns: + None + """ + if global_step == 0: + if not args.sample_at_first: + return + else: + if args.sample_every_n_steps is None and args.sample_every_n_epochs is None: + return + if args.sample_every_n_epochs is not None: + # sample_every_n_steps は無視する + if epoch is None or epoch % args.sample_every_n_epochs != 0: + return + else: + if ( + global_step % args.sample_every_n_steps != 0 or epoch is not None + ): # steps is not divisible or end of epoch + return + + assert ( + args.sample_prompts is not None + ), "No sample prompts found. Provide `--sample_prompts` / サンプルプロンプトが見つかりません。`--sample_prompts` を指定してください" + + logger.info("") + logger.info( + f"generating sample images at step / サンプル画像生成 ステップ: {global_step}" + ) + if ( + not os.path.isfile(args.sample_prompts) + and sample_prompts_gemma2_outputs is None + ): + logger.error( + f"No prompt file / プロンプトファイルがありません: {args.sample_prompts}" + ) + return + + distributed_state = ( + PartialState() + ) # for multi gpu distributed inference. this is a singleton, so it's safe to use it here + + # unwrap nextdit and gemma2_model + nextdit = accelerator.unwrap_model(nextdit) + if gemma2_model is not None: + gemma2_model = accelerator.unwrap_model(gemma2_model) + # if controlnet is not None: + # controlnet = accelerator.unwrap_model(controlnet) + # print([(te.parameters().__next__().device if te is not None else None) for te in text_encoders]) + + prompts = train_util.load_prompts(args.sample_prompts) + + save_dir = args.output_dir + "/sample" + os.makedirs(save_dir, exist_ok=True) + + # save random state to restore later + rng_state = torch.get_rng_state() + cuda_rng_state = None + try: + cuda_rng_state = ( + torch.cuda.get_rng_state() if torch.cuda.is_available() else None + ) + except Exception: + pass + + batch_size = args.sample_batch_size or args.train_batch_size or 1 + + if distributed_state.num_processes <= 1: + # If only one device is available, just use the original prompt list. We don't need to care about the distribution of prompts. + # TODO: batch prompts together with buckets of image sizes + for prompt_dicts in batchify(prompts, batch_size): + sample_image_inference( + accelerator, + args, + nextdit, + gemma2_model, + vae, + save_dir, + prompt_dicts, + epoch, + global_step, + sample_prompts_gemma2_outputs, + prompt_replacement, + controlnet, + ) + else: + # Creating list with N elements, where each element is a list of prompt_dicts, and N is the number of processes available (number of devices available) + # prompt_dicts are assigned to lists based on order of processes, to attempt to time the image creation time to match enum order. Probably only works when steps and sampler are identical. + per_process_prompts = [] # list of lists + for i in range(distributed_state.num_processes): + per_process_prompts.append(prompts[i :: distributed_state.num_processes]) + + with distributed_state.split_between_processes( + per_process_prompts + ) as prompt_dict_lists: + # TODO: batch prompts together with buckets of image sizes + for prompt_dicts in batchify(prompt_dict_lists[0], batch_size): + sample_image_inference( + accelerator, + args, + nextdit, + gemma2_model, + vae, + save_dir, + prompt_dicts, + epoch, + global_step, + sample_prompts_gemma2_outputs, + prompt_replacement, + controlnet, + ) + + torch.set_rng_state(rng_state) + if cuda_rng_state is not None: + torch.cuda.set_rng_state(cuda_rng_state) + + clean_memory_on_device(accelerator.device) + + +@torch.no_grad() +def sample_image_inference( + accelerator: Accelerator, + args: argparse.Namespace, + nextdit: lumina_models.NextDiT, + gemma2_model: list[Gemma2Model], + vae: AutoEncoder, + save_dir: str, + prompt_dicts: list[Dict[str, str]], + epoch: int, + global_step: int, + sample_prompts_gemma2_outputs: dict[str, Tuple[Tensor, Tensor, Tensor]], + prompt_replacement: Optional[Tuple[str, str]] = None, + controlnet=None, +): + """ + Generates sample images + + Args: + accelerator (Accelerator): Accelerator object + args (argparse.Namespace): Arguments object + nextdit (lumina_models.NextDiT): NextDiT model + gemma2_model (list[Gemma2Model]): Gemma2 model + vae (AutoEncoder): VAE model + save_dir (str): Directory to save images + prompt_dict (Dict[str, str]): Prompt dictionary + epoch (int): Epoch number + steps (int): Number of steps to run + sample_prompts_gemma2_outputs (List[Tuple[Tensor, Tensor, Tensor]]): List of tuples containing Gemma 2 outputs + prompt_replacement (Optional[Tuple[str, str]], optional): Replacement for positive and negative prompt. Defaults to None. + + Returns: + None + """ + + # encode prompts + tokenize_strategy = strategy_base.TokenizeStrategy.get_strategy() + encoding_strategy = strategy_base.TextEncodingStrategy.get_strategy() + + assert isinstance(tokenize_strategy, strategy_lumina.LuminaTokenizeStrategy) + assert isinstance(encoding_strategy, strategy_lumina.LuminaTextEncodingStrategy) + + text_conds = [] + + # assuming seed, width, height, sample steps, guidance are the same + width = int(prompt_dicts[0].get("width", 1024)) + height = int(prompt_dicts[0].get("height", 1024)) + height = max(64, height - height % 8) # round to divisible by 8 + width = max(64, width - width % 8) # round to divisible by 8 + + guidance_scale = float(prompt_dicts[0].get("scale", 3.5)) + cfg_trunc_ratio = float(prompt_dicts[0].get("cfg_trunc_ratio", 0.25)) + renorm_cfg = float(prompt_dicts[0].get("renorm_cfg", 1.0)) + sample_steps = int(prompt_dicts[0].get("sample_steps", 36)) + seed = prompt_dicts[0].get("seed", None) + seed = int(seed) if seed is not None else None + assert seed is None or seed > 0, f"Invalid seed {seed}" + generator = torch.Generator(device=accelerator.device) + if seed is not None: + generator.manual_seed(seed) + + for prompt_dict in prompt_dicts: + controlnet_image = prompt_dict.get("controlnet_image") + prompt: str = prompt_dict.get("prompt", "") + negative_prompt = prompt_dict.get("negative_prompt", "") + # sampler_name: str = prompt_dict.get("sample_sampler", args.sample_sampler) + + if prompt_replacement is not None: + prompt = prompt.replace(prompt_replacement[0], prompt_replacement[1]) + if negative_prompt is not None: + negative_prompt = negative_prompt.replace( + prompt_replacement[0], prompt_replacement[1] + ) + + if negative_prompt is None: + negative_prompt = "" + logger.info(f"prompt: {prompt}") + logger.info(f"negative_prompt: {negative_prompt}") + logger.info(f"height: {height}") + logger.info(f"width: {width}") + logger.info(f"sample_steps: {sample_steps}") + logger.info(f"scale: {guidance_scale}") + logger.info(f"trunc: {cfg_trunc_ratio}") + logger.info(f"renorm: {renorm_cfg}") + # logger.info(f"sample_sampler: {sampler_name}") + + + # No need to add system prompt here, as it has been handled in the tokenize_strategy + + # Get sample prompts from cache + if sample_prompts_gemma2_outputs and prompt in sample_prompts_gemma2_outputs: + gemma2_conds = sample_prompts_gemma2_outputs[prompt] + logger.info(f"Using cached Gemma2 outputs for prompt: {prompt}") + + if ( + sample_prompts_gemma2_outputs + and negative_prompt in sample_prompts_gemma2_outputs + ): + neg_gemma2_conds = sample_prompts_gemma2_outputs[negative_prompt] + logger.info( + f"Using cached Gemma2 outputs for negative prompt: {negative_prompt}" + ) + + # Load sample prompts from Gemma 2 + if gemma2_model is not None: + tokens_and_masks = tokenize_strategy.tokenize(prompt) + gemma2_conds = encoding_strategy.encode_tokens( + tokenize_strategy, gemma2_model, tokens_and_masks + ) + + tokens_and_masks = tokenize_strategy.tokenize(negative_prompt, is_negative=True) + neg_gemma2_conds = encoding_strategy.encode_tokens( + tokenize_strategy, gemma2_model, tokens_and_masks + ) + + # Unpack Gemma2 outputs + gemma2_hidden_states, _, gemma2_attn_mask = gemma2_conds + neg_gemma2_hidden_states, _, neg_gemma2_attn_mask = neg_gemma2_conds + + text_conds.append( + ( + gemma2_hidden_states.squeeze(0), + gemma2_attn_mask.squeeze(0), + neg_gemma2_hidden_states.squeeze(0), + neg_gemma2_attn_mask.squeeze(0), + ) + ) + + # Stack conditioning + cond_hidden_states = torch.stack([text_cond[0] for text_cond in text_conds]).to( + accelerator.device + ) + cond_attn_masks = torch.stack([text_cond[1] for text_cond in text_conds]).to( + accelerator.device + ) + uncond_hidden_states = torch.stack([text_cond[2] for text_cond in text_conds]).to( + accelerator.device + ) + uncond_attn_masks = torch.stack([text_cond[3] for text_cond in text_conds]).to( + accelerator.device + ) + + # sample image + weight_dtype = vae.dtype # TOFO give dtype as argument + latent_height = height // 8 + latent_width = width // 8 + latent_channels = 16 + noise = torch.randn( + 1, + latent_channels, + latent_height, + latent_width, + device=accelerator.device, + dtype=weight_dtype, + generator=generator, + ) + noise = noise.repeat(cond_hidden_states.shape[0], 1, 1, 1) + + scheduler = FlowMatchEulerDiscreteScheduler(shift=6.0) + timesteps, num_inference_steps = retrieve_timesteps( + scheduler, num_inference_steps=sample_steps + ) + + # if controlnet_image is not None: + # controlnet_image = Image.open(controlnet_image).convert("RGB") + # controlnet_image = controlnet_image.resize((width, height), Image.LANCZOS) + # controlnet_image = torch.from_numpy((np.array(controlnet_image) / 127.5) - 1) + # controlnet_image = controlnet_image.permute(2, 0, 1).unsqueeze(0).to(weight_dtype).to(accelerator.device) + + with accelerator.autocast(): + x = denoise( + scheduler, + nextdit, + noise, + cond_hidden_states, + cond_attn_masks, + uncond_hidden_states, + uncond_attn_masks, + timesteps=timesteps, + guidance_scale=guidance_scale, + cfg_trunc_ratio=cfg_trunc_ratio, + renorm_cfg=renorm_cfg, + ) + + # Latent to image + clean_memory_on_device(accelerator.device) + org_vae_device = vae.device # will be on cpu + vae.to(accelerator.device) # distributed_state.device is same as accelerator.device + for img, prompt_dict in zip(x, prompt_dicts): + + img = (img / vae.scale_factor) + vae.shift_factor + + with accelerator.autocast(): + # Add a single batch image for the VAE to decode + img = vae.decode(img.unsqueeze(0)) + + img = img.clamp(-1, 1) + img = img.permute(0, 2, 3, 1) # B, H, W, C + # Scale images back to 0 to 255 + img = (127.5 * (img + 1.0)).float().cpu().numpy().astype(np.uint8) + + # Get single image + image = Image.fromarray(img[0]) + + # adding accelerator.wait_for_everyone() here should sync up and ensure that sample images are saved in the same order as the original prompt list + # but adding 'enum' to the filename should be enough + + ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime()) + num_suffix = f"e{epoch:06d}" if epoch is not None else f"{global_step:06d}" + seed_suffix = "" if seed is None else f"_{seed}" + i: int = int(prompt_dict.get("enum", 0)) + img_filename = f"{'' if args.output_name is None else args.output_name + '_'}{num_suffix}_{i:02d}_{ts_str}{seed_suffix}.png" + image.save(os.path.join(save_dir, img_filename)) + + # send images to wandb if enabled + if "wandb" in [tracker.name for tracker in accelerator.trackers]: + wandb_tracker = accelerator.get_tracker("wandb") + + import wandb + + # not to commit images to avoid inconsistency between training and logging steps + wandb_tracker.log( + {f"sample_{i}": wandb.Image(image, caption=prompt)}, commit=False + ) # positive prompt as a caption + + vae.to(org_vae_device) + clean_memory_on_device(accelerator.device) + + +def time_shift(mu: float, sigma: float, t: torch.Tensor): + # the following implementation was original for t=0: clean / t=1: noise + # Since we adopt the reverse, the 1-t operations are needed + t = 1 - t + t = math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma) + t = 1 - t + return t + + +def get_lin_function( + x1: float = 256, x2: float = 4096, y1: float = 0.5, y2: float = 1.15 +) -> Callable[[float], float]: + """ + Get linear function + + Args: + image_seq_len, + x1 base_seq_len: int = 256, + y2 max_seq_len: int = 4096, + y1 base_shift: float = 0.5, + y2 max_shift: float = 1.15, + + Return: + Callable[[float], float]: linear function + """ + m = (y2 - y1) / (x2 - x1) + b = y1 - m * x1 + return lambda x: m * x + b + + +def get_schedule( + num_steps: int, + image_seq_len: int, + base_shift: float = 0.5, + max_shift: float = 1.15, + shift: bool = True, +) -> list[float]: + """ + Get timesteps schedule + + Args: + num_steps (int): Number of steps in the schedule. + image_seq_len (int): Sequence length of the image. + base_shift (float, optional): Base shift value. Defaults to 0.5. + max_shift (float, optional): Maximum shift value. Defaults to 1.15. + shift (bool, optional): Whether to shift the schedule. Defaults to True. + + Return: + List[float]: timesteps schedule + """ + timesteps = torch.linspace(1, 1 / num_steps, num_steps) + + # shifting the schedule to favor high timesteps for higher signal images + if shift: + # eastimate mu based on linear estimation between two points + mu = get_lin_function(y1=base_shift, y2=max_shift, x1=256, x2=4096)( + image_seq_len + ) + timesteps = time_shift(mu, 1.0, timesteps) + + return timesteps.tolist() + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +) -> Tuple[torch.Tensor, int]: + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError( + "Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values" + ) + if timesteps is not None: + accepts_timesteps = "timesteps" in set( + inspect.signature(scheduler.set_timesteps).parameters.keys() + ) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set( + inspect.signature(scheduler.set_timesteps).parameters.keys() + ) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + +def denoise( + scheduler, + model: lumina_models.NextDiT, + img: Tensor, + txt: Tensor, + txt_mask: Tensor, + neg_txt: Tensor, + neg_txt_mask: Tensor, + timesteps: Union[List[float], torch.Tensor], + guidance_scale: float = 4.0, + cfg_trunc_ratio: float = 0.25, + renorm_cfg: float = 1.0, +): + """ + Denoise an image using the NextDiT model. + + Args: + scheduler (): + Noise scheduler + model (lumina_models.NextDiT): The NextDiT model instance. + img (Tensor): + The input image latent tensor. + txt (Tensor): + The input text tensor. + txt_mask (Tensor): + The input text mask tensor. + neg_txt (Tensor): + The negative input txt tensor + neg_txt_mask (Tensor): + The negative input text mask tensor. + timesteps (List[Union[float, torch.FloatTensor]]): + A list of timesteps for the denoising process. + guidance_scale (float, optional): + The guidance scale for the denoising process. Defaults to 4.0. + cfg_trunc_ratio (float, optional): + The ratio of the timestep interval to apply normalization-based guidance scale. + renorm_cfg (float, optional): + The factor to limit the maximum norm after guidance. Default: 1.0 + Returns: + img (Tensor): Denoised latent tensor + """ + + for i, t in enumerate(tqdm(timesteps)): + model.prepare_block_swap_before_forward() + + # reverse the timestep since Lumina uses t=0 as the noise and t=1 as the image + current_timestep = 1 - t / scheduler.config.num_train_timesteps + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + current_timestep = current_timestep * torch.ones( + img.shape[0], device=img.device + ) + + noise_pred_cond = model( + img, + current_timestep, + cap_feats=txt, # Gemma2的hidden states作为caption features + cap_mask=txt_mask.to(dtype=torch.int32), # Gemma2的attention mask + ) + + # compute whether to apply classifier-free guidance based on current timestep + if current_timestep[0] < cfg_trunc_ratio: + model.prepare_block_swap_before_forward() + noise_pred_uncond = model( + img, + current_timestep, + cap_feats=neg_txt, # Gemma2的hidden states作为caption features + cap_mask=neg_txt_mask.to(dtype=torch.int32), # Gemma2的attention mask + ) + noise_pred = noise_pred_uncond + guidance_scale * ( + noise_pred_cond - noise_pred_uncond + ) + # apply normalization after classifier-free guidance + if float(renorm_cfg) > 0.0: + cond_norm = torch.linalg.vector_norm( + noise_pred_cond, + dim=tuple(range(1, len(noise_pred_cond.shape))), + keepdim=True, + ) + max_new_norms = cond_norm * float(renorm_cfg) + noise_norms = torch.linalg.vector_norm( + noise_pred, dim=tuple(range(1, len(noise_pred.shape))), keepdim=True + ) + # Iterate through batch + for i, (noise_norm, max_new_norm) in enumerate(zip(noise_norms, max_new_norms)): + if noise_norm >= max_new_norm: + noise_pred[i] = noise_pred[i] * (max_new_norm / noise_norm) + else: + noise_pred = noise_pred_cond + + img_dtype = img.dtype + + if img.dtype != img_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + img = img.to(img_dtype) + + # compute the previous noisy sample x_t -> x_t-1 + noise_pred = -noise_pred + img = scheduler.step(noise_pred, t, img, return_dict=False)[0] + + model.prepare_block_swap_before_forward() + return img + + +# endregion + + +# region train +def get_sigmas( + noise_scheduler: FlowMatchEulerDiscreteScheduler, + timesteps: Tensor, + device: torch.device, + n_dim=4, + dtype=torch.float32, +) -> Tensor: + """ + Get sigmas for timesteps + + Args: + noise_scheduler (FlowMatchEulerDiscreteScheduler): The noise scheduler instance. + timesteps (Tensor): A tensor of timesteps for the denoising process. + device (torch.device): The device on which the tensors are stored. + n_dim (int, optional): The number of dimensions for the output tensor. Defaults to 4. + dtype (torch.dtype, optional): The data type for the output tensor. Defaults to torch.float32. + + Returns: + sigmas (Tensor): The sigmas tensor. + """ + sigmas = noise_scheduler.sigmas.to(device=device, dtype=dtype) + schedule_timesteps = noise_scheduler.timesteps.to(device) + timesteps = timesteps.to(device) + step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] + + sigma = sigmas[step_indices].flatten() + while len(sigma.shape) < n_dim: + sigma = sigma.unsqueeze(-1) + return sigma + + +def compute_density_for_timestep_sampling( + weighting_scheme: str, + batch_size: int, + logit_mean: float = None, + logit_std: float = None, + mode_scale: float = None, +): + """ + Compute the density for sampling the timesteps when doing SD3 training. + + Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528. + + SD3 paper reference: https://arxiv.org/abs/2403.03206v1. + + Args: + weighting_scheme (str): The weighting scheme to use. + batch_size (int): The batch size for the sampling process. + logit_mean (float, optional): The mean of the logit distribution. Defaults to None. + logit_std (float, optional): The standard deviation of the logit distribution. Defaults to None. + mode_scale (float, optional): The mode scale for the mode weighting scheme. Defaults to None. + + Returns: + u (Tensor): The sampled timesteps. + """ + if weighting_scheme == "logit_normal": + # See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$). + u = torch.normal( + mean=logit_mean, std=logit_std, size=(batch_size,), device="cpu" + ) + u = torch.nn.functional.sigmoid(u) + elif weighting_scheme == "mode": + u = torch.rand(size=(batch_size,), device="cpu") + u = 1 - u - mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u) + else: + u = torch.rand(size=(batch_size,), device="cpu") + return u + + +def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None) -> Tensor: + """Computes loss weighting scheme for SD3 training. + + Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528. + + SD3 paper reference: https://arxiv.org/abs/2403.03206v1. + + Args: + weighting_scheme (str): The weighting scheme to use. + sigmas (Tensor, optional): The sigmas tensor. Defaults to None. + + Returns: + u (Tensor): The sampled timesteps. + """ + if weighting_scheme == "sigma_sqrt": + weighting = (sigmas**-2.0).float() + elif weighting_scheme == "cosmap": + bot = 1 - 2 * sigmas + 2 * sigmas**2 + weighting = 2 / (math.pi * bot) + else: + weighting = torch.ones_like(sigmas) + return weighting + + +def get_noisy_model_input_and_timesteps( + args, noise_scheduler, latents, noise, device, dtype +) -> Tuple[Tensor, Tensor, Tensor]: + """ + Get noisy model input and timesteps. + + Args: + args (argparse.Namespace): Arguments. + noise_scheduler (noise_scheduler): Noise scheduler. + latents (Tensor): Latents. + noise (Tensor): Latent noise. + device (torch.device): Device. + dtype (torch.dtype): Data type + + Return: + Tuple[Tensor, Tensor, Tensor]: + noisy model input + timesteps + sigmas + """ + bsz, _, h, w = latents.shape + sigmas = None + + if args.timestep_sampling == "uniform" or args.timestep_sampling == "sigmoid": + # Simple random t-based noise sampling + if args.timestep_sampling == "sigmoid": + # https://github.com/XLabs-AI/x-flux/tree/main + t = torch.sigmoid(args.sigmoid_scale * torch.randn((bsz,), device=device)) + else: + t = torch.rand((bsz,), device=device) + + timesteps = t * 1000.0 + t = t.view(-1, 1, 1, 1) + noisy_model_input = (1 - t) * noise + t * latents + elif args.timestep_sampling == "shift": + shift = args.discrete_flow_shift + logits_norm = torch.randn(bsz, device=device) + logits_norm = ( + logits_norm * args.sigmoid_scale + ) # larger scale for more uniform sampling + timesteps = logits_norm.sigmoid() + timesteps = (timesteps * shift) / (1 + (shift - 1) * timesteps) + + t = timesteps.view(-1, 1, 1, 1) + timesteps = timesteps * 1000.0 + noisy_model_input = (1 - t) * noise + t * latents + elif args.timestep_sampling == "nextdit_shift": + t = torch.rand((bsz,), device=device) + mu = get_lin_function(y1=0.5, y2=1.15)((h // 2) * (w // 2)) + t = time_shift(mu, 1.0, t) + + timesteps = t * 1000.0 + t = t.view(-1, 1, 1, 1) + noisy_model_input = (1 - t) * noise + t * latents + else: + # Sample a random timestep for each image + # for weighting schemes where we sample timesteps non-uniformly + u = compute_density_for_timestep_sampling( + weighting_scheme=args.weighting_scheme, + batch_size=bsz, + logit_mean=args.logit_mean, + logit_std=args.logit_std, + mode_scale=args.mode_scale, + ) + indices = (u * noise_scheduler.config.num_train_timesteps).long() + timesteps = noise_scheduler.timesteps[indices].to(device=device) + + # Add noise according to flow matching. + sigmas = get_sigmas( + noise_scheduler, timesteps, device, n_dim=latents.ndim, dtype=dtype + ) + noisy_model_input = sigmas * latents + (1.0 - sigmas) * noise + + return noisy_model_input.to(dtype), timesteps.to(dtype), sigmas + + +def apply_model_prediction_type( + args, model_pred: Tensor, noisy_model_input: Tensor, sigmas: Tensor +) -> Tuple[Tensor, Optional[Tensor]]: + """ + Apply model prediction type to the model prediction and the sigmas. + + Args: + args (argparse.Namespace): Arguments. + model_pred (Tensor): Model prediction. + noisy_model_input (Tensor): Noisy model input. + sigmas (Tensor): Sigmas. + + Return: + Tuple[Tensor, Optional[Tensor]]: + """ + weighting = None + if args.model_prediction_type == "raw": + pass + elif args.model_prediction_type == "additive": + # add the model_pred to the noisy_model_input + model_pred = model_pred + noisy_model_input + elif args.model_prediction_type == "sigma_scaled": + # apply sigma scaling + model_pred = model_pred * (-sigmas) + noisy_model_input + + # these weighting schemes use a uniform timestep sampling + # and instead post-weight the loss + weighting = compute_loss_weighting_for_sd3( + weighting_scheme=args.weighting_scheme, sigmas=sigmas + ) + + return model_pred, weighting + + +def save_models( + ckpt_path: str, + lumina: lumina_models.NextDiT, + sai_metadata: Dict[str, Any], + save_dtype: Optional[torch.dtype] = None, + use_mem_eff_save: bool = False, +): + """ + Save the model to the checkpoint path. + + Args: + ckpt_path (str): Path to the checkpoint. + lumina (lumina_models.NextDiT): NextDIT model. + sai_metadata (Optional[dict]): Metadata for the SAI model. + save_dtype (Optional[torch.dtype]): Data + + Return: + None + """ + state_dict = {} + + def update_sd(prefix, sd): + for k, v in sd.items(): + key = prefix + k + if save_dtype is not None and v.dtype != save_dtype: + v = v.detach().clone().to("cpu").to(save_dtype) + state_dict[key] = v + + update_sd("", lumina.state_dict()) + + if not use_mem_eff_save: + save_file(state_dict, ckpt_path, metadata=sai_metadata) + else: + mem_eff_save_file(state_dict, ckpt_path, metadata=sai_metadata) + + +def save_lumina_model_on_train_end( + args: argparse.Namespace, + save_dtype: torch.dtype, + epoch: int, + global_step: int, + lumina: lumina_models.NextDiT, +): + def sd_saver(ckpt_file, epoch_no, global_step): + sai_metadata = train_util.get_sai_model_spec( + None, + args, + False, + False, + False, + is_stable_diffusion_ckpt=True, + lumina="lumina2", + ) + save_models(ckpt_file, lumina, sai_metadata, save_dtype, args.mem_eff_save) + + train_util.save_sd_model_on_train_end_common( + args, True, True, epoch, global_step, sd_saver, None + ) + + +# epochとstepの保存、メタデータにepoch/stepが含まれ引数が同じになるため、統合してている +# on_epoch_end: Trueならepoch終了時、Falseならstep経過時 +def save_lumina_model_on_epoch_end_or_stepwise( + args: argparse.Namespace, + on_epoch_end: bool, + accelerator: Accelerator, + save_dtype: torch.dtype, + epoch: int, + num_train_epochs: int, + global_step: int, + lumina: lumina_models.NextDiT, +): + """ + Save the model to the checkpoint path. + + Args: + args (argparse.Namespace): Arguments. + save_dtype (torch.dtype): Data type. + epoch (int): Epoch. + global_step (int): Global step. + lumina (lumina_models.NextDiT): NextDIT model. + + Return: + None + """ + + def sd_saver(ckpt_file: str, epoch_no: int, global_step: int): + sai_metadata = train_util.get_sai_model_spec( + {}, + args, + False, + False, + False, + is_stable_diffusion_ckpt=True, + lumina="lumina2", + ) + save_models(ckpt_file, lumina, sai_metadata, save_dtype, args.mem_eff_save) + + train_util.save_sd_model_on_epoch_end_or_stepwise_common( + args, + on_epoch_end, + accelerator, + True, + True, + epoch, + num_train_epochs, + global_step, + sd_saver, + None, + ) + + +# endregion + + +def add_lumina_train_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--gemma2", + type=str, + help="path to gemma2 model (*.sft or *.safetensors), should be float16 / gemma2のパス(*.sftまたは*.safetensors)、float16が前提", + ) + parser.add_argument( + "--ae", + type=str, + help="path to ae (*.sft or *.safetensors) / aeのパス(*.sftまたは*.safetensors)", + ) + parser.add_argument( + "--gemma2_max_token_length", + type=int, + default=None, + help="maximum token length for Gemma2. if omitted, 256" + " / Gemma2の最大トークン長。省略された場合、256になります", + ) + + parser.add_argument( + "--timestep_sampling", + choices=["sigma", "uniform", "sigmoid", "shift", "nextdit_shift"], + default="shift", + help="Method to sample timesteps: sigma-based, uniform random, sigmoid of random normal, shift of sigmoid and NextDIT.1 shifting. Default is 'shift'." + " / タイムステップをサンプリングする方法:sigma、random uniform、random normalのsigmoid、sigmoidのシフト、NextDIT.1のシフト。デフォルトは'shift'です。", + ) + parser.add_argument( + "--sigmoid_scale", + type=float, + default=1.0, + help='Scale factor for sigmoid timestep sampling (only used when timestep-sampling is "sigmoid"). / sigmoidタイムステップサンプリングの倍率(timestep-samplingが"sigmoid"の場合のみ有効)。', + ) + parser.add_argument( + "--model_prediction_type", + choices=["raw", "additive", "sigma_scaled"], + default="raw", + help="How to interpret and process the model prediction: " + "raw (use as is), additive (add to noisy input), sigma_scaled (apply sigma scaling)." + " / モデル予測の解釈と処理方法:" + "raw(そのまま使用)、additive(ノイズ入力に加算)、sigma_scaled(シグマスケーリングを適用)。", + ) + parser.add_argument( + "--discrete_flow_shift", + type=float, + default=6.0, + help="Discrete flow shift for the Euler Discrete Scheduler, default is 6.0 / Euler Discrete Schedulerの離散フローシフト、デフォルトは6.0", + ) + parser.add_argument( + "--use_flash_attn", + action="store_true", + help="Use Flash Attention for the model / モデルにFlash Attentionを使用する", + ) + parser.add_argument( + "--use_sage_attn", + action="store_true", + help="Use Sage Attention for the model / モデルにSage Attentionを使用する", + ) + parser.add_argument( + "--system_prompt", + type=str, + default="", + help="System prompt to add to the prompt / プロンプトに追加するシステムプロンプト", + ) + parser.add_argument( + "--sample_batch_size", + type=int, + default=None, + help="Batch size to use for sampling, defaults to --training_batch_size value. Sample batches are bucketed by width, height, guidance scale, and seed / サンプリングに使用するバッチサイズ。デフォルトは --training_batch_size の値です。サンプルバッチは、幅、高さ、ガイダンススケール、シードによってバケット化されます", + ) diff --git a/library/lumina_util.py b/library/lumina_util.py new file mode 100644 index 000000000..f7f3c8231 --- /dev/null +++ b/library/lumina_util.py @@ -0,0 +1,259 @@ +import json +import os +from dataclasses import replace +from typing import List, Optional, Tuple, Union + +import einops +import torch +from accelerate import init_empty_weights +from safetensors import safe_open +from safetensors.torch import load_file +from transformers import Gemma2Config, Gemma2Model + +from library.utils import setup_logging +from library import lumina_models, flux_models +from library.safetensors_utils import load_safetensors +import logging + +setup_logging() +logger = logging.getLogger(__name__) + +MODEL_VERSION_LUMINA_V2 = "lumina2" + + +def load_lumina_model( + ckpt_path: str, + dtype: Optional[torch.dtype], + device: torch.device, + disable_mmap: bool = False, + use_flash_attn: bool = False, + use_sage_attn: bool = False, +): + """ + Load the Lumina model from the checkpoint path. + + Args: + ckpt_path (str): Path to the checkpoint. + dtype (torch.dtype): The data type for the model. + device (torch.device): The device to load the model on. + disable_mmap (bool, optional): Whether to disable mmap. Defaults to False. + use_flash_attn (bool, optional): Whether to use flash attention. Defaults to False. + + Returns: + model (lumina_models.NextDiT): The loaded model. + """ + logger.info("Building Lumina") + with torch.device("meta"): + model = lumina_models.NextDiT_2B_GQA_patch2_Adaln_Refiner(use_flash_attn=use_flash_attn, use_sage_attn=use_sage_attn).to( + dtype + ) + + logger.info(f"Loading state dict from {ckpt_path}") + state_dict = load_safetensors(ckpt_path, device=device, disable_mmap=disable_mmap, dtype=dtype) + + # Neta-Lumina support + if "model.diffusion_model.cap_embedder.0.weight" in state_dict: + # remove "model.diffusion_model." prefix + filtered_state_dict = { + k.replace("model.diffusion_model.", ""): v for k, v in state_dict.items() if k.startswith("model.diffusion_model.") + } + state_dict = filtered_state_dict + + info = model.load_state_dict(state_dict, strict=False, assign=True) + logger.info(f"Loaded Lumina: {info}") + return model + + +def load_ae( + ckpt_path: str, + dtype: torch.dtype, + device: Union[str, torch.device], + disable_mmap: bool = False, +) -> flux_models.AutoEncoder: + """ + Load the AutoEncoder model from the checkpoint path. + + Args: + ckpt_path (str): Path to the checkpoint. + dtype (torch.dtype): The data type for the model. + device (Union[str, torch.device]): The device to load the model on. + disable_mmap (bool, optional): Whether to disable mmap. Defaults to False. + + Returns: + ae (flux_models.AutoEncoder): The loaded model. + """ + logger.info("Building AutoEncoder") + with torch.device("meta"): + # dev and schnell have the same AE params + ae = flux_models.AutoEncoder(flux_models.configs["schnell"].ae_params).to(dtype) + + logger.info(f"Loading state dict from {ckpt_path}") + sd = load_safetensors(ckpt_path, device=device, disable_mmap=disable_mmap, dtype=dtype) + + # Neta-Lumina support + if "vae.decoder.conv_in.bias" in sd: + # remove "vae." prefix + filtered_sd = {k.replace("vae.", ""): v for k, v in sd.items() if k.startswith("vae.")} + sd = filtered_sd + + info = ae.load_state_dict(sd, strict=False, assign=True) + logger.info(f"Loaded AE: {info}") + return ae + + +def load_gemma2( + ckpt_path: Optional[str], + dtype: torch.dtype, + device: Union[str, torch.device], + disable_mmap: bool = False, + state_dict: Optional[dict] = None, +) -> Gemma2Model: + """ + Load the Gemma2 model from the checkpoint path. + + Args: + ckpt_path (str): Path to the checkpoint. + dtype (torch.dtype): The data type for the model. + device (Union[str, torch.device]): The device to load the model on. + disable_mmap (bool, optional): Whether to disable mmap. Defaults to False. + state_dict (Optional[dict], optional): The state dict to load. Defaults to None. + + Returns: + gemma2 (Gemma2Model): The loaded model + """ + logger.info("Building Gemma2") + GEMMA2_CONFIG = { + "_name_or_path": "google/gemma-2-2b", + "architectures": ["Gemma2Model"], + "attention_bias": False, + "attention_dropout": 0.0, + "attn_logit_softcapping": 50.0, + "bos_token_id": 2, + "cache_implementation": "hybrid", + "eos_token_id": 1, + "final_logit_softcapping": 30.0, + "head_dim": 256, + "hidden_act": "gelu_pytorch_tanh", + "hidden_activation": "gelu_pytorch_tanh", + "hidden_size": 2304, + "initializer_range": 0.02, + "intermediate_size": 9216, + "max_position_embeddings": 8192, + "model_type": "gemma2", + "num_attention_heads": 8, + "num_hidden_layers": 26, + "num_key_value_heads": 4, + "pad_token_id": 0, + "query_pre_attn_scalar": 256, + "rms_norm_eps": 1e-06, + "rope_theta": 10000.0, + "sliding_window": 4096, + "torch_dtype": "float32", + "transformers_version": "4.44.2", + "use_cache": True, + "vocab_size": 256000, + } + + config = Gemma2Config(**GEMMA2_CONFIG) + with init_empty_weights(): + gemma2 = Gemma2Model._from_config(config) + + if state_dict is not None: + sd = state_dict + else: + logger.info(f"Loading state dict from {ckpt_path}") + sd = load_safetensors(ckpt_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype) + + for key in list(sd.keys()): + new_key = key.replace("model.", "") + if new_key == key: + break # the model doesn't have annoying prefix + sd[new_key] = sd.pop(key) + + # Neta-Lumina support + if "text_encoders.gemma2_2b.logit_scale" in sd: + # remove "text_encoders.gemma2_2b.transformer.model." prefix + filtered_sd = { + k.replace("text_encoders.gemma2_2b.transformer.model.", ""): v + for k, v in sd.items() + if k.startswith("text_encoders.gemma2_2b.transformer.model.") + } + sd = filtered_sd + + info = gemma2.load_state_dict(sd, strict=False, assign=True) + logger.info(f"Loaded Gemma2: {info}") + return gemma2 + + +def unpack_latents(x: torch.Tensor, packed_latent_height: int, packed_latent_width: int) -> torch.Tensor: + """ + x: [b (h w) (c ph pw)] -> [b c (h ph) (w pw)], ph=2, pw=2 + """ + x = einops.rearrange(x, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=packed_latent_height, w=packed_latent_width, ph=2, pw=2) + return x + + +def pack_latents(x: torch.Tensor) -> torch.Tensor: + """ + x: [b c (h ph) (w pw)] -> [b (h w) (c ph pw)], ph=2, pw=2 + """ + x = einops.rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2) + return x + + +DIFFUSERS_TO_ALPHA_VLLM_MAP: dict[str, str] = { + # Embedding layers + "time_caption_embed.caption_embedder.0.weight": "cap_embedder.0.weight", + "time_caption_embed.caption_embedder.1.weight": "cap_embedder.1.weight", + "text_embedder.1.bias": "cap_embedder.1.bias", + "patch_embedder.proj.weight": "x_embedder.weight", + "patch_embedder.proj.bias": "x_embedder.bias", + # Attention modulation + "transformer_blocks.().adaln_modulation.1.weight": "layers.().adaLN_modulation.1.weight", + "transformer_blocks.().adaln_modulation.1.bias": "layers.().adaLN_modulation.1.bias", + # Final layers + "final_adaln_modulation.1.weight": "final_layer.adaLN_modulation.1.weight", + "final_adaln_modulation.1.bias": "final_layer.adaLN_modulation.1.bias", + "final_linear.weight": "final_layer.linear.weight", + "final_linear.bias": "final_layer.linear.bias", + # Noise refiner + "single_transformer_blocks.().adaln_modulation.1.weight": "noise_refiner.().adaLN_modulation.1.weight", + "single_transformer_blocks.().adaln_modulation.1.bias": "noise_refiner.().adaLN_modulation.1.bias", + "single_transformer_blocks.().attn.to_qkv.weight": "noise_refiner.().attention.qkv.weight", + "single_transformer_blocks.().attn.to_out.0.weight": "noise_refiner.().attention.out.weight", + # Normalization + "transformer_blocks.().norm1.weight": "layers.().attention_norm1.weight", + "transformer_blocks.().norm2.weight": "layers.().attention_norm2.weight", + # FFN + "transformer_blocks.().ff.net.0.proj.weight": "layers.().feed_forward.w1.weight", + "transformer_blocks.().ff.net.2.weight": "layers.().feed_forward.w2.weight", + "transformer_blocks.().ff.net.4.weight": "layers.().feed_forward.w3.weight", +} + + +def convert_diffusers_sd_to_alpha_vllm(sd: dict, num_double_blocks: int) -> dict: + """Convert Diffusers checkpoint to Alpha-VLLM format""" + logger.info("Converting Diffusers checkpoint to Alpha-VLLM format") + new_sd = sd.copy() # Preserve original keys + + for diff_key, alpha_key in DIFFUSERS_TO_ALPHA_VLLM_MAP.items(): + # Handle block-specific patterns + if "()." in diff_key: + for block_idx in range(num_double_blocks): + block_alpha_key = alpha_key.replace("().", f"{block_idx}.") + block_diff_key = diff_key.replace("().", f"{block_idx}.") + + # Search for and convert block-specific keys + for input_key, value in list(sd.items()): + if input_key == block_diff_key: + new_sd[block_alpha_key] = value + else: + # Handle static keys + if diff_key in sd: + print(f"Replacing {diff_key} with {alpha_key}") + new_sd[alpha_key] = sd[diff_key] + else: + print(f"Not found: {diff_key}") + + logger.info(f"Converted {len(new_sd)} keys to Alpha-VLLM format") + return new_sd diff --git a/library/model_util.py b/library/model_util.py index 9918c7b2a..bcaa1145b 100644 --- a/library/model_util.py +++ b/library/model_util.py @@ -6,6 +6,7 @@ import torch from library.device_utils import init_ipex + init_ipex() import diffusers @@ -14,8 +15,10 @@ from safetensors.torch import load_file, save_file from library.original_unet import UNet2DConditionModel from library.utils import setup_logging + setup_logging() import logging + logger = logging.getLogger(__name__) # DiffUsers版StableDiffusionのモデルパラメータ @@ -974,7 +977,7 @@ def load_checkpoint_with_text_encoder_conversion(ckpt_path, device="cpu"): checkpoint = None state_dict = load_file(ckpt_path) # , device) # may causes error else: - checkpoint = torch.load(ckpt_path, map_location=device) + checkpoint = torch.load(ckpt_path, map_location=device, weights_only=False) if "state_dict" in checkpoint: state_dict = checkpoint["state_dict"] else: diff --git a/library/original_unet.py b/library/original_unet.py index e944ff22b..aa9dc233b 100644 --- a/library/original_unet.py +++ b/library/original_unet.py @@ -114,8 +114,10 @@ from torch.nn import functional as F from einops import rearrange from library.utils import setup_logging + setup_logging() import logging + logger = logging.getLogger(__name__) BLOCK_OUT_CHANNELS: Tuple[int] = (320, 640, 1280, 1280) @@ -530,7 +532,9 @@ def custom_forward(*inputs): return custom_forward - hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb, use_reentrant=False + ) else: hidden_states = resnet(hidden_states, temb) @@ -626,15 +630,9 @@ def forward(self, hidden_states, context=None, mask=None, **kwargs): hidden_states, encoder_hidden_states, attention_mask, - ) = translate_attention_names_from_diffusers( - hidden_states=hidden_states, context=context, mask=mask, **kwargs - ) + ) = translate_attention_names_from_diffusers(hidden_states=hidden_states, context=context, mask=mask, **kwargs) return self.processor( - attn=self, - hidden_states=hidden_states, - encoder_hidden_states=context, - attention_mask=mask, - **kwargs + attn=self, hidden_states=hidden_states, encoder_hidden_states=context, attention_mask=mask, **kwargs ) if self.use_memory_efficient_attention_xformers: return self.forward_memory_efficient_xformers(hidden_states, context, mask) @@ -748,13 +746,14 @@ def forward_sdpa(self, x, context=None, mask=None): out = self.to_out[0](out) return out + def translate_attention_names_from_diffusers( hidden_states: torch.FloatTensor, context: Optional[torch.FloatTensor] = None, mask: Optional[torch.FloatTensor] = None, # HF naming encoder_hidden_states: Optional[torch.FloatTensor] = None, - attention_mask: Optional[torch.FloatTensor] = None + attention_mask: Optional[torch.FloatTensor] = None, ): # translate from hugging face diffusers context = context if context is not None else encoder_hidden_states @@ -764,6 +763,7 @@ def translate_attention_names_from_diffusers( return hidden_states, context, mask + # feedforward class GEGLU(nn.Module): r""" @@ -1015,9 +1015,11 @@ def custom_forward(*inputs): return custom_forward - hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states + create_custom_forward(resnet), hidden_states, temb, use_reentrant=False + ) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states, use_reentrant=False )[0] else: hidden_states = resnet(hidden_states, temb) @@ -1098,10 +1100,12 @@ def custom_forward(*inputs): if attn is not None: hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states + create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states, use_reentrant=False )[0] - hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb, use_reentrant=False + ) else: if attn is not None: hidden_states = attn(hidden_states, encoder_hidden_states).sample @@ -1201,7 +1205,9 @@ def custom_forward(*inputs): return custom_forward - hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb, use_reentrant=False + ) else: hidden_states = resnet(hidden_states, temb) @@ -1296,9 +1302,11 @@ def custom_forward(*inputs): return custom_forward - hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states + create_custom_forward(resnet), hidden_states, temb, use_reentrant=False + ) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states, use_reentrant=False )[0] else: hidden_states = resnet(hidden_states, temb) diff --git a/library/safetensors_utils.py b/library/safetensors_utils.py new file mode 100644 index 000000000..c65cdfabe --- /dev/null +++ b/library/safetensors_utils.py @@ -0,0 +1,351 @@ +import os +import re +import numpy as np +import torch +import json +import struct +from typing import Dict, Any, Union, Optional + +from safetensors.torch import load_file + +from library.device_utils import synchronize_device + + +def mem_eff_save_file(tensors: Dict[str, torch.Tensor], filename: str, metadata: Dict[str, Any] = None): + """ + memory efficient save file + """ + + _TYPES = { + torch.float64: "F64", + torch.float32: "F32", + torch.float16: "F16", + torch.bfloat16: "BF16", + torch.int64: "I64", + torch.int32: "I32", + torch.int16: "I16", + torch.int8: "I8", + torch.uint8: "U8", + torch.bool: "BOOL", + getattr(torch, "float8_e5m2", None): "F8_E5M2", + getattr(torch, "float8_e4m3fn", None): "F8_E4M3", + } + _ALIGN = 256 + + def validate_metadata(metadata: Dict[str, Any]) -> Dict[str, str]: + validated = {} + for key, value in metadata.items(): + if not isinstance(key, str): + raise ValueError(f"Metadata key must be a string, got {type(key)}") + if not isinstance(value, str): + print(f"Warning: Metadata value for key '{key}' is not a string. Converting to string.") + validated[key] = str(value) + else: + validated[key] = value + return validated + + + header = {} + offset = 0 + if metadata: + header["__metadata__"] = validate_metadata(metadata) + for k, v in tensors.items(): + if v.numel() == 0: # empty tensor + header[k] = {"dtype": _TYPES[v.dtype], "shape": list(v.shape), "data_offsets": [offset, offset]} + else: + size = v.numel() * v.element_size() + header[k] = {"dtype": _TYPES[v.dtype], "shape": list(v.shape), "data_offsets": [offset, offset + size]} + offset += size + + hjson = json.dumps(header).encode("utf-8") + hjson += b" " * (-(len(hjson) + 8) % _ALIGN) + + with open(filename, "wb") as f: + f.write(struct.pack(" Dict[str, str]: + """Get metadata from the file. + + Returns: + Dict[str, str]: Metadata dictionary. + """ + return self.header.get("__metadata__", {}) + + def _read_header(self): + """Read and parse the header from the safetensors file. + + Returns: + tuple: (header_dict, header_size) containing parsed header and its size. + """ + # Read header size (8 bytes, little-endian unsigned long long) + header_size = struct.unpack("10MB) and the target device is CUDA, memory mapping with numpy.memmap is used to avoid intermediate copies. + + Args: + key (str): Name of the tensor to load. + device (Optional[torch.device]): Target device for the tensor. + dtype (Optional[torch.dtype]): Target dtype for the tensor. + + Returns: + torch.Tensor: The loaded tensor. + + Raises: + KeyError: If the tensor key is not found in the file. + """ + if key not in self.header: + raise KeyError(f"Tensor '{key}' not found in the file") + + metadata = self.header[key] + offset_start, offset_end = metadata["data_offsets"] + num_bytes = offset_end - offset_start + + original_dtype = self._get_torch_dtype(metadata["dtype"]) + target_dtype = dtype if dtype is not None else original_dtype + + # Handle empty tensors + if num_bytes == 0: + return torch.empty(metadata["shape"], dtype=target_dtype, device=device) + + # Determine if we should use pinned memory for GPU transfer + non_blocking = device is not None and device.type == "cuda" + + # Calculate absolute file offset + tensor_offset = self.header_size + 8 + offset_start # adjust offset by header size + + # Memory mapping strategy for large tensors to GPU + # Use memmap for large tensors to avoid intermediate copies. + # If device is cpu, tensor is not copied to gpu, so using memmap locks the file, which is not desired. + # So we only use memmap if device is not cpu. + if num_bytes > 10 * 1024 * 1024 and device is not None and device.type != "cpu": + # Create memory map for zero-copy reading + mm = np.memmap(self.filename, mode="c", dtype=np.uint8, offset=tensor_offset, shape=(num_bytes,)) + byte_tensor = torch.from_numpy(mm) # zero copy + del mm + + # Deserialize tensor (view and reshape) + cpu_tensor = self._deserialize_tensor(byte_tensor, metadata) # view and reshape + del byte_tensor + + # Transfer to target device and dtype + gpu_tensor = cpu_tensor.to(device=device, dtype=target_dtype, non_blocking=non_blocking) + del cpu_tensor + return gpu_tensor + + # Standard file reading strategy for smaller tensors or CPU target + # seek to the specified position + self.file.seek(tensor_offset) + + # read directly into a numpy array by numpy.fromfile without intermediate copy + numpy_array = np.fromfile(self.file, dtype=np.uint8, count=num_bytes) + byte_tensor = torch.from_numpy(numpy_array) + del numpy_array + + # deserialize (view and reshape) + deserialized_tensor = self._deserialize_tensor(byte_tensor, metadata) + del byte_tensor + + # cast to target dtype and move to device + return deserialized_tensor.to(device=device, dtype=target_dtype, non_blocking=non_blocking) + + def _deserialize_tensor(self, byte_tensor: torch.Tensor, metadata: Dict): + """Deserialize byte tensor to the correct shape and dtype. + + Args: + byte_tensor (torch.Tensor): Raw byte tensor from file. + metadata (Dict): Tensor metadata containing dtype and shape info. + + Returns: + torch.Tensor: Deserialized tensor with correct shape and dtype. + """ + dtype = self._get_torch_dtype(metadata["dtype"]) + shape = metadata["shape"] + + # Handle special float8 types + if metadata["dtype"] in ["F8_E5M2", "F8_E4M3"]: + return self._convert_float8(byte_tensor, metadata["dtype"], shape) + + # Standard conversion: view as target dtype and reshape + return byte_tensor.view(dtype).reshape(shape) + + @staticmethod + def _get_torch_dtype(dtype_str): + """Convert string dtype to PyTorch dtype. + + Args: + dtype_str (str): String representation of the dtype. + + Returns: + torch.dtype: Corresponding PyTorch dtype. + """ + # Standard dtype mappings + dtype_map = { + "F64": torch.float64, + "F32": torch.float32, + "F16": torch.float16, + "BF16": torch.bfloat16, + "I64": torch.int64, + "I32": torch.int32, + "I16": torch.int16, + "I8": torch.int8, + "U8": torch.uint8, + "BOOL": torch.bool, + } + # Add float8 types if available in PyTorch version + if hasattr(torch, "float8_e5m2"): + dtype_map["F8_E5M2"] = torch.float8_e5m2 + if hasattr(torch, "float8_e4m3fn"): + dtype_map["F8_E4M3"] = torch.float8_e4m3fn + return dtype_map.get(dtype_str) + + @staticmethod + def _convert_float8(byte_tensor, dtype_str, shape): + """Convert byte tensor to float8 format if supported. + + Args: + byte_tensor (torch.Tensor): Raw byte tensor. + dtype_str (str): Float8 dtype string ("F8_E5M2" or "F8_E4M3"). + shape (tuple): Target tensor shape. + + Returns: + torch.Tensor: Tensor with float8 dtype. + + Raises: + ValueError: If float8 type is not supported in current PyTorch version. + """ + # Convert to specific float8 types if available + if dtype_str == "F8_E5M2" and hasattr(torch, "float8_e5m2"): + return byte_tensor.view(torch.float8_e5m2).reshape(shape) + elif dtype_str == "F8_E4M3" and hasattr(torch, "float8_e4m3fn"): + return byte_tensor.view(torch.float8_e4m3fn).reshape(shape) + else: + # Float8 not supported in this PyTorch version + raise ValueError(f"Unsupported float8 type: {dtype_str} (upgrade PyTorch to support float8 types)") + + +def load_safetensors( + path: str, device: Union[str, torch.device], disable_mmap: bool = False, dtype: Optional[torch.dtype] = None +) -> dict[str, torch.Tensor]: + if disable_mmap: + # return safetensors.torch.load(open(path, "rb").read()) + # use experimental loader + # logger.info(f"Loading without mmap (experimental)") + state_dict = {} + device = torch.device(device) if device is not None else None + with MemoryEfficientSafeOpen(path) as f: + for key in f.keys(): + state_dict[key] = f.get_tensor(key, device=device, dtype=dtype) + synchronize_device(device) + return state_dict + else: + try: + state_dict = load_file(path, device=device) + except: + state_dict = load_file(path) # prevent device invalid Error + if dtype is not None: + for key in state_dict.keys(): + state_dict[key] = state_dict[key].to(dtype=dtype) + return state_dict + + +def load_split_weights( + file_path: str, device: Union[str, torch.device] = "cpu", disable_mmap: bool = False, dtype: Optional[torch.dtype] = None +) -> Dict[str, torch.Tensor]: + """ + Load split weights from a file. If the file name ends with 00001-of-00004 etc, it will load all files with the same prefix. + dtype is as is, no conversion is done. + """ + device = torch.device(device) + + # if the file name ends with 00001-of-00004 etc, we need to load the files with the same prefix + basename = os.path.basename(file_path) + match = re.match(r"^(.*?)(\d+)-of-(\d+)\.safetensors$", basename) + if match: + prefix = basename[: match.start(2)] + count = int(match.group(3)) + state_dict = {} + for i in range(count): + filename = f"{prefix}{i + 1:05d}-of-{count:05d}.safetensors" + filepath = os.path.join(os.path.dirname(file_path), filename) + if os.path.exists(filepath): + state_dict.update(load_safetensors(filepath, device=device, disable_mmap=disable_mmap, dtype=dtype)) + else: + raise FileNotFoundError(f"File {filepath} not found") + else: + state_dict = load_safetensors(file_path, device=device, disable_mmap=disable_mmap, dtype=dtype) + return state_dict + + +def find_key(safetensors_file: str, starts_with: Optional[str] = None, ends_with: Optional[str] = None) -> Optional[str]: + """ + Find a key in a safetensors file that starts with `starts_with` and ends with `ends_with`. + If `starts_with` is None, it will match any key. + If `ends_with` is None, it will match any key. + Returns the first matching key or None if no key matches. + """ + with MemoryEfficientSafeOpen(safetensors_file) as f: + for key in f.keys(): + if (starts_with is None or key.startswith(starts_with)) and (ends_with is None or key.endswith(ends_with)): + return key + return None diff --git a/library/sai_model_spec.py b/library/sai_model_spec.py index a63bd82ec..32a4fd7bf 100644 --- a/library/sai_model_spec.py +++ b/library/sai_model_spec.py @@ -1,13 +1,20 @@ # based on https://github.com/Stability-AI/ModelSpec import datetime import hashlib +import argparse +import base64 +import logging +import mimetypes +import subprocess +from dataclasses import dataclass, field from io import BytesIO import os -from typing import List, Optional, Tuple, Union +from typing import Union import safetensors from library.utils import setup_logging + setup_logging() -import logging + logger = logging.getLogger(__name__) r""" @@ -29,23 +36,32 @@ """ BASE_METADATA = { - # === Must === - "modelspec.sai_model_spec": "1.0.0", # Required version ID for the spec + # === MUST === + "modelspec.sai_model_spec": "1.0.1", "modelspec.architecture": None, "modelspec.implementation": None, "modelspec.title": None, "modelspec.resolution": None, - # === Should === + # === SHOULD === "modelspec.description": None, "modelspec.author": None, "modelspec.date": None, - # === Can === + "modelspec.hash_sha256": None, + # === CAN=== + "modelspec.implementation_version": None, "modelspec.license": None, + "modelspec.usage_hint": None, + "modelspec.thumbnail": None, "modelspec.tags": None, "modelspec.merged_from": None, + "modelspec.trigger_phrase": None, "modelspec.prediction_type": None, "modelspec.timestep_range": None, "modelspec.encoder_layer": None, + "modelspec.preprocessor": None, + "modelspec.is_negative_embedding": None, + "modelspec.unet_dtype": None, + "modelspec.vae_dtype": None, } # 別に使うやつだけ定義 @@ -55,17 +71,270 @@ ARCH_SD_V2_512 = "stable-diffusion-v2-512" ARCH_SD_V2_768_V = "stable-diffusion-v2-768-v" ARCH_SD_XL_V1_BASE = "stable-diffusion-xl-v1-base" +ARCH_SD3_M = "stable-diffusion-3" # may be followed by "-m" or "-5-large" etc. +# ARCH_SD3_UNKNOWN = "stable-diffusion-3" +ARCH_FLUX_1_DEV = "flux-1-dev" +ARCH_FLUX_1_SCHNELL = "flux-1-schnell" +ARCH_FLUX_1_CHROMA = "chroma" # for Flux Chroma +ARCH_FLUX_1_UNKNOWN = "flux-1" +ARCH_LUMINA_2 = "lumina-2" +ARCH_LUMINA_UNKNOWN = "lumina" +ARCH_HUNYUAN_IMAGE_2_1 = "hunyuan-image-2.1" +ARCH_HUNYUAN_IMAGE_UNKNOWN = "hunyuan-image" ADAPTER_LORA = "lora" ADAPTER_TEXTUAL_INVERSION = "textual-inversion" IMPL_STABILITY_AI = "https://github.com/Stability-AI/generative-models" +IMPL_COMFY_UI = "https://github.com/comfyanonymous/ComfyUI" IMPL_DIFFUSERS = "diffusers" +IMPL_FLUX = "https://github.com/black-forest-labs/flux" +IMPL_CHROMA = "https://huggingface.co/lodestones/Chroma" +IMPL_LUMINA = "https://github.com/Alpha-VLLM/Lumina-Image-2.0" +IMPL_HUNYUAN_IMAGE = "https://github.com/Tencent-Hunyuan/HunyuanImage-2.1" PRED_TYPE_EPSILON = "epsilon" PRED_TYPE_V = "v" +@dataclass +class ModelSpecMetadata: + """ + ModelSpec 1.0.1 compliant metadata for safetensors models. + All fields correspond to modelspec.* keys in the final metadata. + """ + + # === MUST === + architecture: str + implementation: str + title: str + resolution: str + sai_model_spec: str = "1.0.1" + + # === SHOULD === + description: str | None = None + author: str | None = None + date: str | None = None + hash_sha256: str | None = None + + # === CAN === + implementation_version: str | None = None + license: str | None = None + usage_hint: str | None = None + thumbnail: str | None = None + tags: str | None = None + merged_from: str | None = None + trigger_phrase: str | None = None + prediction_type: str | None = None + timestep_range: str | None = None + encoder_layer: str | None = None + preprocessor: str | None = None + is_negative_embedding: str | None = None + unet_dtype: str | None = None + vae_dtype: str | None = None + + # === Additional metadata === + additional_fields: dict[str, str] = field(default_factory=dict) + + def to_metadata_dict(self) -> dict[str, str]: + """Convert dataclass to metadata dictionary with modelspec. prefixes.""" + metadata = {} + + # Add all non-None fields with modelspec prefix + for field_name, value in self.__dict__.items(): + if field_name == "additional_fields": + # Handle additional fields separately + for key, val in value.items(): + if key.startswith("modelspec."): + metadata[key] = val + else: + metadata[f"modelspec.{key}"] = val + elif value is not None: + metadata[f"modelspec.{field_name}"] = value + + return metadata + + @classmethod + def from_args(cls, args, **kwargs) -> "ModelSpecMetadata": + """Create ModelSpecMetadata from argparse Namespace, extracting metadata_* fields.""" + metadata_fields = {} + + # Extract all metadata_* attributes from args + for attr_name in dir(args): + if attr_name.startswith("metadata_") and not attr_name.startswith("metadata___"): + value = getattr(args, attr_name, None) + if value is not None: + # Remove metadata_ prefix + field_name = attr_name[9:] # len("metadata_") = 9 + metadata_fields[field_name] = value + + # Handle known standard fields + standard_fields = { + "author": metadata_fields.pop("author", None), + "description": metadata_fields.pop("description", None), + "license": metadata_fields.pop("license", None), + "tags": metadata_fields.pop("tags", None), + } + + # Remove None values + standard_fields = {k: v for k, v in standard_fields.items() if v is not None} + + # Merge with kwargs and remaining metadata fields + all_fields = {**standard_fields, **kwargs} + if metadata_fields: + all_fields["additional_fields"] = metadata_fields + + return cls(**all_fields) + + +def determine_architecture( + v2: bool, v_parameterization: bool, sdxl: bool, lora: bool, textual_inversion: bool, model_config: dict[str, str] | None = None +) -> str: + """Determine model architecture string from parameters.""" + + model_config = model_config or {} + + if sdxl: + arch = ARCH_SD_XL_V1_BASE + elif "sd3" in model_config: + arch = ARCH_SD3_M + "-" + model_config["sd3"] + elif "flux" in model_config: + flux_type = model_config["flux"] + if flux_type == "dev": + arch = ARCH_FLUX_1_DEV + elif flux_type == "schnell": + arch = ARCH_FLUX_1_SCHNELL + elif flux_type == "chroma": + arch = ARCH_FLUX_1_CHROMA + else: + arch = ARCH_FLUX_1_UNKNOWN + elif "lumina" in model_config: + lumina_type = model_config["lumina"] + if lumina_type == "lumina2": + arch = ARCH_LUMINA_2 + else: + arch = ARCH_LUMINA_UNKNOWN + elif "hunyuan_image" in model_config: + hunyuan_image_type = model_config["hunyuan_image"] + if hunyuan_image_type == "2.1": + arch = ARCH_HUNYUAN_IMAGE_2_1 + else: + arch = ARCH_HUNYUAN_IMAGE_UNKNOWN + elif v2: + arch = ARCH_SD_V2_768_V if v_parameterization else ARCH_SD_V2_512 + else: + arch = ARCH_SD_V1 + + # Add adapter suffix + if lora: + arch += f"/{ADAPTER_LORA}" + elif textual_inversion: + arch += f"/{ADAPTER_TEXTUAL_INVERSION}" + + return arch + + +def determine_implementation( + lora: bool, + textual_inversion: bool, + sdxl: bool, + model_config: dict[str, str] | None = None, + is_stable_diffusion_ckpt: bool | None = None, +) -> str: + """Determine implementation string from parameters.""" + + model_config = model_config or {} + + if "flux" in model_config: + if model_config["flux"] == "chroma": + return IMPL_CHROMA + else: + return IMPL_FLUX + elif "lumina" in model_config: + return IMPL_LUMINA + elif (lora and sdxl) or textual_inversion or is_stable_diffusion_ckpt: + return IMPL_STABILITY_AI + else: + return IMPL_DIFFUSERS + + +def get_implementation_version() -> str: + """Get the current implementation version as sd-scripts/{commit_hash}.""" + try: + # Get the git commit hash + result = subprocess.run( + ["git", "rev-parse", "HEAD"], + capture_output=True, + text=True, + cwd=os.path.dirname(os.path.dirname(__file__)), # Go up to sd-scripts root + timeout=5, + ) + + if result.returncode == 0: + commit_hash = result.stdout.strip() + return f"sd-scripts/{commit_hash}" + else: + logger.warning("Failed to get git commit hash, using fallback") + return "sd-scripts/unknown" + + except (subprocess.TimeoutExpired, subprocess.SubprocessError, FileNotFoundError) as e: + logger.warning(f"Could not determine git commit: {e}") + return "sd-scripts/unknown" + + +def file_to_data_url(file_path: str) -> str: + """Convert a file path to a data URL for embedding in metadata.""" + if not os.path.exists(file_path): + raise FileNotFoundError(f"File not found: {file_path}") + + # Get MIME type + mime_type, _ = mimetypes.guess_type(file_path) + if mime_type is None: + # Default to binary if we can't detect + mime_type = "application/octet-stream" + + # Read file and encode as base64 + with open(file_path, "rb") as f: + file_data = f.read() + + encoded_data = base64.b64encode(file_data).decode("ascii") + + return f"data:{mime_type};base64,{encoded_data}" + + +def determine_resolution( + reso: Union[int, tuple[int, int]] | None = None, + sdxl: bool = False, + model_config: dict[str, str] | None = None, + v2: bool = False, + v_parameterization: bool = False, +) -> str: + """Determine resolution string from parameters.""" + + model_config = model_config or {} + + if reso is not None: + # Handle comma separated string + if isinstance(reso, str): + reso = tuple(map(int, reso.split(","))) + # Handle single int + if isinstance(reso, int): + reso = (reso, reso) + # Handle single-element tuple + if len(reso) == 1: + reso = (reso[0], reso[0]) + else: + # Determine default resolution based on model type + if sdxl or "sd3" in model_config or "flux" in model_config or "lumina" in model_config: + reso = (1024, 1024) + elif v2 and v_parameterization: + reso = (768, 768) + else: + reso = (512, 512) + + return f"{reso[0]}x{reso[1]}" + + def load_bytes_in_safetensors(tensors): bytes = safetensors.torch.save(tensors) b = BytesIO(bytes) @@ -95,62 +364,42 @@ def update_hash_sha256(metadata: dict, state_dict: dict): raise NotImplementedError -def build_metadata( - state_dict: Optional[dict], +def build_metadata_dataclass( + state_dict: dict | None, v2: bool, v_parameterization: bool, sdxl: bool, lora: bool, textual_inversion: bool, timestamp: float, - title: Optional[str] = None, - reso: Optional[Union[int, Tuple[int, int]]] = None, - is_stable_diffusion_ckpt: Optional[bool] = None, - author: Optional[str] = None, - description: Optional[str] = None, - license: Optional[str] = None, - tags: Optional[str] = None, - merged_from: Optional[str] = None, - timesteps: Optional[Tuple[int, int]] = None, - clip_skip: Optional[int] = None, -): - # if state_dict is None, hash is not calculated - - metadata = {} - metadata.update(BASE_METADATA) - - # TODO メモリを消費せずかつ正しいハッシュ計算の方法がわかったら実装する - # if state_dict is not None: - # hash = precalculate_safetensors_hashes(state_dict) - # metadata["modelspec.hash_sha256"] = hash - - if sdxl: - arch = ARCH_SD_XL_V1_BASE - elif v2: - if v_parameterization: - arch = ARCH_SD_V2_768_V - else: - arch = ARCH_SD_V2_512 - else: - arch = ARCH_SD_V1 - - if lora: - arch += f"/{ADAPTER_LORA}" - elif textual_inversion: - arch += f"/{ADAPTER_TEXTUAL_INVERSION}" - - metadata["modelspec.architecture"] = arch + title: str | None = None, + reso: int | tuple[int, int] | None = None, + is_stable_diffusion_ckpt: bool | None = None, + author: str | None = None, + description: str | None = None, + license: str | None = None, + tags: str | None = None, + merged_from: str | None = None, + timesteps: tuple[int, int] | None = None, + clip_skip: int | None = None, + model_config: dict | None = None, + optional_metadata: dict | None = None, +) -> ModelSpecMetadata: + """ + Build ModelSpec 1.0.1 compliant metadata dataclass. + + Args: + model_config: Dict containing model type info, e.g. {"flux": "dev"}, {"sd3": "large"} + optional_metadata: Dict of additional metadata fields to include + """ + + # Use helper functions for complex logic + architecture = determine_architecture(v2, v_parameterization, sdxl, lora, textual_inversion, model_config) if not lora and not textual_inversion and is_stable_diffusion_ckpt is None: - is_stable_diffusion_ckpt = True # default is stable diffusion ckpt if not lora and not textual_inversion + is_stable_diffusion_ckpt = True # default is stable diffusion ckpt if not lora and not textual_inversion - if (lora and sdxl) or textual_inversion or is_stable_diffusion_ckpt: - # Stable Diffusion ckpt, TI, SDXL LoRA - impl = IMPL_STABILITY_AI - else: - # v1/v2 LoRA or Diffusers - impl = IMPL_DIFFUSERS - metadata["modelspec.implementation"] = impl + implementation = determine_implementation(lora, textual_inversion, sdxl, model_config, is_stable_diffusion_ckpt) if title is None: if lora: @@ -160,97 +409,150 @@ def build_metadata( else: title = "Checkpoint" title += f"@{timestamp}" - metadata[MODELSPEC_TITLE] = title - - if author is not None: - metadata["modelspec.author"] = author - else: - del metadata["modelspec.author"] - - if description is not None: - metadata["modelspec.description"] = description - else: - del metadata["modelspec.description"] - - if merged_from is not None: - metadata["modelspec.merged_from"] = merged_from - else: - del metadata["modelspec.merged_from"] - - if license is not None: - metadata["modelspec.license"] = license - else: - del metadata["modelspec.license"] - - if tags is not None: - metadata["modelspec.tags"] = tags - else: - del metadata["modelspec.tags"] # remove microsecond from time int_ts = int(timestamp) - # time to iso-8601 compliant date date = datetime.datetime.fromtimestamp(int_ts).isoformat() - metadata["modelspec.date"] = date - if reso is not None: - # comma separated to tuple - if isinstance(reso, str): - reso = tuple(map(int, reso.split(","))) - if len(reso) == 1: - reso = (reso[0], reso[0]) - else: - # resolution is defined in dataset, so use default - if sdxl: - reso = 1024 - elif v2 and v_parameterization: - reso = 768 - else: - reso = 512 - if isinstance(reso, int): - reso = (reso, reso) - - metadata["modelspec.resolution"] = f"{reso[0]}x{reso[1]}" + # Use helper function for resolution + resolution = determine_resolution(reso, sdxl, model_config, v2, v_parameterization) - if v_parameterization: - metadata["modelspec.prediction_type"] = PRED_TYPE_V - else: - metadata["modelspec.prediction_type"] = PRED_TYPE_EPSILON + # Handle prediction type - Flux models don't use prediction_type + model_config = model_config or {} + prediction_type = None + if "flux" not in model_config: + if v_parameterization: + prediction_type = PRED_TYPE_V + else: + prediction_type = PRED_TYPE_EPSILON + # Handle timesteps + timestep_range = None if timesteps is not None: if isinstance(timesteps, str) or isinstance(timesteps, int): timesteps = (timesteps, timesteps) if len(timesteps) == 1: timesteps = (timesteps[0], timesteps[0]) - metadata["modelspec.timestep_range"] = f"{timesteps[0]},{timesteps[1]}" - else: - del metadata["modelspec.timestep_range"] + timestep_range = f"{timesteps[0]},{timesteps[1]}" + # Handle encoder layer (clip skip) + encoder_layer = None if clip_skip is not None: - metadata["modelspec.encoder_layer"] = f"{clip_skip}" - else: - del metadata["modelspec.encoder_layer"] + encoder_layer = f"{clip_skip}" + + # TODO: Implement hash calculation when memory-efficient method is available + # hash_sha256 = None + # if state_dict is not None: + # hash_sha256 = precalculate_safetensors_hashes(state_dict) + + # Process thumbnail - convert file path to data URL if needed + processed_optional_metadata = optional_metadata.copy() if optional_metadata else {} + if "thumbnail" in processed_optional_metadata: + thumbnail_value = processed_optional_metadata["thumbnail"] + # Check if it's already a data URL or if it's a file path + if thumbnail_value and not thumbnail_value.startswith("data:"): + try: + processed_optional_metadata["thumbnail"] = file_to_data_url(thumbnail_value) + logger.info(f"Converted thumbnail file {thumbnail_value} to data URL") + except FileNotFoundError as e: + logger.warning(f"Thumbnail file not found, skipping: {e}") + del processed_optional_metadata["thumbnail"] + except Exception as e: + logger.warning(f"Failed to convert thumbnail to data URL: {e}") + del processed_optional_metadata["thumbnail"] + + # Automatically set implementation version if not provided + if "implementation_version" not in processed_optional_metadata: + processed_optional_metadata["implementation_version"] = get_implementation_version() + + # Create the dataclass + metadata = ModelSpecMetadata( + architecture=architecture, + implementation=implementation, + title=title, + description=description, + author=author, + date=date, + license=license, + tags=tags, + merged_from=merged_from, + resolution=resolution, + prediction_type=prediction_type, + timestep_range=timestep_range, + encoder_layer=encoder_layer, + additional_fields=processed_optional_metadata, + ) - # # assert all values are filled - # assert all([v is not None for v in metadata.values()]), metadata - if not all([v is not None for v in metadata.values()]): - logger.error(f"Internal error: some metadata values are None: {metadata}") - return metadata +def build_metadata( + state_dict: dict | None, + v2: bool, + v_parameterization: bool, + sdxl: bool, + lora: bool, + textual_inversion: bool, + timestamp: float, + title: str | None = None, + reso: int | tuple[int, int] | None = None, + is_stable_diffusion_ckpt: bool | None = None, + author: str | None = None, + description: str | None = None, + license: str | None = None, + tags: str | None = None, + merged_from: str | None = None, + timesteps: tuple[int, int] | None = None, + clip_skip: int | None = None, + model_config: dict | None = None, + optional_metadata: dict | None = None, +) -> dict[str, str]: + """ + Build ModelSpec 1.0.1 compliant metadata for safetensors models. + Legacy function that returns dict - prefer build_metadata_dataclass for new code. + + Args: + model_config: Dict containing model type info, e.g. {"flux": "dev"}, {"sd3": "large"} + optional_metadata: Dict of additional metadata fields to include + """ + # Use the dataclass function and convert to dict + metadata_obj = build_metadata_dataclass( + state_dict=state_dict, + v2=v2, + v_parameterization=v_parameterization, + sdxl=sdxl, + lora=lora, + textual_inversion=textual_inversion, + timestamp=timestamp, + title=title, + reso=reso, + is_stable_diffusion_ckpt=is_stable_diffusion_ckpt, + author=author, + description=description, + license=license, + tags=tags, + merged_from=merged_from, + timesteps=timesteps, + clip_skip=clip_skip, + model_config=model_config, + optional_metadata=optional_metadata, + ) + + return metadata_obj.to_metadata_dict() + + # region utils -def get_title(metadata: dict) -> Optional[str]: +def get_title(metadata: dict) -> str | None: return metadata.get(MODELSPEC_TITLE, None) def load_metadata_from_safetensors(model: str) -> dict: if not model.endswith(".safetensors"): return {} - + with safetensors.safe_open(model, framework="pt") as f: metadata = f.metadata() if metadata is None: @@ -258,7 +560,7 @@ def load_metadata_from_safetensors(model: str) -> dict: return metadata -def build_merged_from(models: List[str]) -> str: +def build_merged_from(models: list[str]) -> str: def get_title(model: str): metadata = load_metadata_from_safetensors(model) title = metadata.get(MODELSPEC_TITLE, None) @@ -270,6 +572,77 @@ def get_title(model: str): return ", ".join(titles) +def add_model_spec_arguments(parser: argparse.ArgumentParser): + """Add all ModelSpec metadata arguments to the parser.""" + + parser.add_argument( + "--metadata_title", + type=str, + default=None, + help="title for model metadata (default is output_name) / メタデータに書き込まれるモデルタイトル、省略時はoutput_name", + ) + parser.add_argument( + "--metadata_author", + type=str, + default=None, + help="author name for model metadata / メタデータに書き込まれるモデル作者名", + ) + parser.add_argument( + "--metadata_description", + type=str, + default=None, + help="description for model metadata / メタデータに書き込まれるモデル説明", + ) + parser.add_argument( + "--metadata_license", + type=str, + default=None, + help="license for model metadata / メタデータに書き込まれるモデルライセンス", + ) + parser.add_argument( + "--metadata_tags", + type=str, + default=None, + help="tags for model metadata, separated by comma / メタデータに書き込まれるモデルタグ、カンマ区切り", + ) + parser.add_argument( + "--metadata_usage_hint", + type=str, + default=None, + help="usage hint for model metadata / メタデータに書き込まれる使用方法のヒント", + ) + parser.add_argument( + "--metadata_thumbnail", + type=str, + default=None, + help="thumbnail image as data URL or file path (will be converted to data URL) for model metadata / メタデータに書き込まれるサムネイル画像(データURLまたはファイルパス、ファイルパスの場合はデータURLに変換されます)", + ) + parser.add_argument( + "--metadata_merged_from", + type=str, + default=None, + help="source models for merged model metadata / メタデータに書き込まれるマージ元モデル名", + ) + parser.add_argument( + "--metadata_trigger_phrase", + type=str, + default=None, + help="trigger phrase for model metadata / メタデータに書き込まれるトリガーフレーズ", + ) + parser.add_argument( + "--metadata_preprocessor", + type=str, + default=None, + help="preprocessor used for model metadata / メタデータに書き込まれる前処理手法", + ) + parser.add_argument( + "--metadata_is_negative_embedding", + type=str, + default=None, + help="whether this is a negative embedding for model metadata / メタデータに書き込まれるネガティブ埋め込みかどうか", + ) + + # endregion diff --git a/library/sd3_models.py b/library/sd3_models.py new file mode 100644 index 000000000..996f81920 --- /dev/null +++ b/library/sd3_models.py @@ -0,0 +1,1428 @@ +# some modules/classes are copied and modified from https://github.com/mcmonkey4eva/sd3-ref +# the original code is licensed under the MIT License + +# and some module/classes are contributed from KohakuBlueleaf. Thanks for the contribution! + +from ast import Tuple +from concurrent.futures import ThreadPoolExecutor +from dataclasses import dataclass +from functools import partial +import math +from types import SimpleNamespace +from typing import Dict, List, Optional, Union +import einops +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.utils.checkpoint import checkpoint +from transformers import CLIPTokenizer, T5TokenizerFast + +from library import custom_offloading_utils +from library.device_utils import clean_memory_on_device + +from .utils import setup_logging + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + + +memory_efficient_attention = None +try: + import xformers +except: + pass + +try: + from xformers.ops import memory_efficient_attention +except: + memory_efficient_attention = None + + +# region mmdit + + +@dataclass +class SD3Params: + patch_size: int + depth: int + num_patches: int + pos_embed_max_size: int + adm_in_channels: int + qk_norm: Optional[str] + x_block_self_attn_layers: list[int] + context_embedder_in_features: int + context_embedder_out_features: int + model_type: str + + +def get_2d_sincos_pos_embed( + embed_dim, + grid_size, + scaling_factor=None, + offset=None, +): + grid_h = np.arange(grid_size, dtype=np.float32) + grid_w = np.arange(grid_size, dtype=np.float32) + grid = np.meshgrid(grid_w, grid_h) # here w goes first + grid = np.stack(grid, axis=0) + if scaling_factor is not None: + grid = grid / scaling_factor + if offset is not None: + grid = grid - offset + + grid = grid.reshape([2, 1, grid_size, grid_size]) + pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) + return pos_embed + + +def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): + assert embed_dim % 2 == 0 + + # use half of dimensions to encode grid_h + emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) + emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) + + emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) + return emb + + +def get_scaled_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0, sample_size=64, base_size=16): + """ + This function is contributed by KohakuBlueleaf. Thanks for the contribution! + + Creates scaled 2D sinusoidal positional embeddings that maintain consistent relative positions + when the resolution differs from the training resolution. + + Args: + embed_dim (int): Dimension of the positional embedding. + grid_size (int or tuple): Size of the position grid (H, W). If int, assumes square grid. + cls_token (bool): Whether to include class token. Defaults to False. + extra_tokens (int): Number of extra tokens (e.g., cls_token). Defaults to 0. + sample_size (int): Reference resolution (typically training resolution). Defaults to 64. + base_size (int): Base grid size used during training. Defaults to 16. + + Returns: + numpy.ndarray: Positional embeddings of shape (H*W, embed_dim) or + (H*W + extra_tokens, embed_dim) if cls_token is True. + """ + # Convert grid_size to tuple if it's an integer + if isinstance(grid_size, int): + grid_size = (grid_size, grid_size) + + # Create normalized grid coordinates (0 to 1) + grid_h = np.arange(grid_size[0], dtype=np.float32) / grid_size[0] + grid_w = np.arange(grid_size[1], dtype=np.float32) / grid_size[1] + + # Calculate scaling factors for height and width + # This ensures that the central region matches the original resolution's embeddings + scale_h = base_size * grid_size[0] / (sample_size) + scale_w = base_size * grid_size[1] / (sample_size) + + # Calculate shift values to center the original resolution's embedding region + # This ensures that the central sample_size x sample_size region has similar + # positional embeddings to the original resolution + shift_h = 1 * scale_h * (grid_size[0] - sample_size) / (2 * grid_size[0]) + shift_w = 1 * scale_w * (grid_size[1] - sample_size) / (2 * grid_size[1]) + + # Apply scaling and shifting to create the final grid coordinates + grid_h = grid_h * scale_h - shift_h + grid_w = grid_w * scale_w - shift_w + + # Create 2D grid using meshgrid (note: w goes first) + grid = np.meshgrid(grid_w, grid_h) + grid = np.stack(grid, axis=0) + + # # Calculate the starting indices for the central region + # # This is used for debugging/visualization of the central region + # st_h = (grid_size[0] - sample_size) // 2 + # st_w = (grid_size[1] - sample_size) // 2 + # print(grid[:, st_h : st_h + sample_size, st_w : st_w + sample_size]) + + # Reshape grid for positional embedding calculation + grid = grid.reshape([2, 1, grid_size[1], grid_size[0]]) + + # Generate the sinusoidal positional embeddings + pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) + + # Add zeros for extra tokens (e.g., [CLS] token) if required + if cls_token and extra_tokens > 0: + pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0) + + return pos_embed + + +# if __name__ == "__main__": +# # This is what you get when you load SD3.5 state dict +# pos_emb = torch.from_numpy(get_scaled_2d_sincos_pos_embed( +# 1536, [384, 384], sample_size=64, base_size=16 +# )).float().unsqueeze(0) + + +def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): + """ + embed_dim: output dimension for each position + pos: a list of positions to be encoded: size (M,) + out: (M, D) + """ + assert embed_dim % 2 == 0 + omega = np.arange(embed_dim // 2, dtype=np.float64) + omega /= embed_dim / 2.0 + omega = 1.0 / 10000**omega # (D/2,) + + pos = pos.reshape(-1) # (M,) + out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product + + emb_sin = np.sin(out) # (M, D/2) + emb_cos = np.cos(out) # (M, D/2) + + emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) + return emb + + +def get_1d_sincos_pos_embed_from_grid_torch( + embed_dim, + pos, + device=None, + dtype=torch.float32, +): + omega = torch.arange(embed_dim // 2, device=device, dtype=dtype) + omega *= 2.0 / embed_dim + omega = 1.0 / 10000**omega + out = torch.outer(pos.reshape(-1), omega) + emb = torch.cat([out.sin(), out.cos()], dim=1) + return emb + + +def get_2d_sincos_pos_embed_torch( + embed_dim, + w, + h, + val_center=7.5, + val_magnitude=7.5, + device=None, + dtype=torch.float32, +): + small = min(h, w) + val_h = (h / small) * val_magnitude + val_w = (w / small) * val_magnitude + grid_h, grid_w = torch.meshgrid( + torch.linspace(-val_h + val_center, val_h + val_center, h, device=device, dtype=dtype), + torch.linspace(-val_w + val_center, val_w + val_center, w, device=device, dtype=dtype), + indexing="ij", + ) + emb_h = get_1d_sincos_pos_embed_from_grid_torch(embed_dim // 2, grid_h, device=device, dtype=dtype) + emb_w = get_1d_sincos_pos_embed_from_grid_torch(embed_dim // 2, grid_w, device=device, dtype=dtype) + emb = torch.cat([emb_w, emb_h], dim=1) # (H*W, D) + return emb + + +def modulate(x, shift, scale): + if shift is None: + shift = torch.zeros_like(scale) + return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) + + +def default(x, default_value): + if x is None: + return default_value + return x + + +def timestep_embedding(t, dim, max_period=10000): + half = dim // 2 + # freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to( + # device=t.device, dtype=t.dtype + # ) + freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(device=t.device) + args = t[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + if torch.is_floating_point(t): + embedding = embedding.to(dtype=t.dtype) + return embedding + + +class PatchEmbed(nn.Module): + def __init__( + self, + img_size=256, + patch_size=4, + in_channels=3, + embed_dim=512, + norm_layer=None, + flatten=True, + bias=True, + strict_img_size=True, + dynamic_img_pad=False, + ): + # dynamic_img_pad and norm is omitted in SD3.5 + super().__init__() + self.patch_size = patch_size + self.flatten = flatten + self.strict_img_size = strict_img_size + self.dynamic_img_pad = dynamic_img_pad + if img_size is not None: + self.img_size = img_size + self.grid_size = img_size // patch_size + self.num_patches = self.grid_size**2 + else: + self.img_size = None + self.grid_size = None + self.num_patches = None + + self.proj = nn.Conv2d(in_channels, embed_dim, patch_size, patch_size, bias=bias) + self.norm = nn.Identity() if norm_layer is None else norm_layer(embed_dim) + + def forward(self, x): + B, C, H, W = x.shape + + if self.dynamic_img_pad: + # Pad input so we won't have partial patch + pad_h = (self.patch_size - H % self.patch_size) % self.patch_size + pad_w = (self.patch_size - W % self.patch_size) % self.patch_size + x = nn.functional.pad(x, (0, pad_w, 0, pad_h), mode="reflect") + x = self.proj(x) + if self.flatten: + x = x.flatten(2).transpose(1, 2) + x = self.norm(x) + return x + + +# FinalLayer in mmdit.py +class UnPatch(nn.Module): + def __init__(self, hidden_size=512, patch_size=4, out_channels=3): + super().__init__() + self.patch_size = patch_size + self.c = out_channels + + # eps is default in mmdit.py + self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.linear = nn.Linear(hidden_size, patch_size**2 * out_channels) + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + nn.Linear(hidden_size, 2 * hidden_size), + ) + + def forward(self, x: torch.Tensor, cmod, H=None, W=None): + b, n, _ = x.shape + p = self.patch_size + c = self.c + if H is None and W is None: + w = h = int(n**0.5) + assert h * w == n + else: + h = H // p if H else n // (W // p) + w = W // p if W else n // h + assert h * w == n + + shift, scale = self.adaLN_modulation(cmod).chunk(2, dim=-1) + x = modulate(self.norm_final(x), shift, scale) + x = self.linear(x) + + x = x.view(b, h, w, p, p, c) + x = x.permute(0, 5, 1, 3, 2, 4).contiguous() + x = x.view(b, c, h * p, w * p) + return x + + +class MLP(nn.Module): + def __init__( + self, + in_features, + hidden_features=None, + out_features=None, + act_layer=lambda: nn.GELU(), + norm_layer=None, + bias=True, + use_conv=False, + ): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.use_conv = use_conv + + layer = partial(nn.Conv1d, kernel_size=1) if use_conv else nn.Linear + + self.fc1 = layer(in_features, hidden_features, bias=bias) + self.fc2 = layer(hidden_features, out_features, bias=bias) + self.act = act_layer() + self.norm = norm_layer(hidden_features) if norm_layer else nn.Identity() + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.norm(x) + x = self.fc2(x) + return x + + +class TimestepEmbedding(nn.Module): + def __init__(self, hidden_size, freq_embed_size=256): + super().__init__() + self.mlp = nn.Sequential( + nn.Linear(freq_embed_size, hidden_size), + nn.SiLU(), + nn.Linear(hidden_size, hidden_size), + ) + self.freq_embed_size = freq_embed_size + + def forward(self, t, dtype=None, **kwargs): + t_freq = timestep_embedding(t, self.freq_embed_size).to(dtype) + t_emb = self.mlp(t_freq) + return t_emb + + +class Embedder(nn.Module): + def __init__(self, input_dim, hidden_size): + super().__init__() + self.mlp = nn.Sequential( + nn.Linear(input_dim, hidden_size), + nn.SiLU(), + nn.Linear(hidden_size, hidden_size), + ) + + def forward(self, x): + return self.mlp(x) + + +def rmsnorm(x, eps=1e-6): + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) + + +class RMSNorm(torch.nn.Module): + def __init__( + self, + dim: int, + elementwise_affine: bool = False, + eps: float = 1e-6, + device=None, + dtype=None, + ): + """ + Initialize the RMSNorm normalization layer. + Args: + dim (int): The dimension of the input tensor. + eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6. + Attributes: + eps (float): A small value added to the denominator for numerical stability. + weight (nn.Parameter): Learnable scaling parameter. + """ + super().__init__() + self.eps = eps + self.learnable_scale = elementwise_affine + if self.learnable_scale: + self.weight = nn.Parameter(torch.empty(dim, device=device, dtype=dtype)) + else: + self.register_parameter("weight", None) + + def forward(self, x): + """ + Forward pass through the RMSNorm layer. + Args: + x (torch.Tensor): The input tensor. + Returns: + torch.Tensor: The output tensor after applying RMSNorm. + """ + x = rmsnorm(x, eps=self.eps) + if self.learnable_scale: + return x * self.weight.to(device=x.device, dtype=x.dtype) + else: + return x + + +class SwiGLUFeedForward(nn.Module): + def __init__( + self, + dim: int, + hidden_dim: int, + multiple_of: int, + ffn_dim_multiplier: float = None, + ): + super().__init__() + hidden_dim = int(2 * hidden_dim / 3) + # custom dim factor multiplier + if ffn_dim_multiplier is not None: + hidden_dim = int(ffn_dim_multiplier * hidden_dim) + hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) + + self.w1 = nn.Linear(dim, hidden_dim, bias=False) + self.w2 = nn.Linear(hidden_dim, dim, bias=False) + self.w3 = nn.Linear(dim, hidden_dim, bias=False) + + def forward(self, x): + return self.w2(nn.functional.silu(self.w1(x)) * self.w3(x)) + + +# Linears for SelfAttention in mmdit.py +class AttentionLinears(nn.Module): + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = False, + pre_only: bool = False, + qk_norm: Optional[str] = None, + ): + super().__init__() + self.num_heads = num_heads + self.head_dim = dim // num_heads + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + if not pre_only: + self.proj = nn.Linear(dim, dim) + self.pre_only = pre_only + + if qk_norm == "rms": + self.ln_q = RMSNorm(self.head_dim, elementwise_affine=True, eps=1.0e-6) + self.ln_k = RMSNorm(self.head_dim, elementwise_affine=True, eps=1.0e-6) + elif qk_norm == "ln": + self.ln_q = nn.LayerNorm(self.head_dim, elementwise_affine=True, eps=1.0e-6) + self.ln_k = nn.LayerNorm(self.head_dim, elementwise_affine=True, eps=1.0e-6) + elif qk_norm is None: + self.ln_q = nn.Identity() + self.ln_k = nn.Identity() + else: + raise ValueError(qk_norm) + + def pre_attention(self, x: torch.Tensor) -> torch.Tensor: + """ + output: + q, k, v: [B, L, D] + """ + B, L, C = x.shape + qkv: torch.Tensor = self.qkv(x) + q, k, v = qkv.reshape(B, L, -1, self.head_dim).chunk(3, dim=2) + q = self.ln_q(q).reshape(q.shape[0], q.shape[1], -1) + k = self.ln_k(k).reshape(q.shape[0], q.shape[1], -1) + return (q, k, v) + + def post_attention(self, x: torch.Tensor) -> torch.Tensor: + assert not self.pre_only + x = self.proj(x) + return x + + +MEMORY_LAYOUTS = { + "torch": ( + lambda x, head_dim: x.reshape(x.shape[0], x.shape[1], -1, head_dim).transpose(1, 2), + lambda x: x.transpose(1, 2).reshape(x.shape[0], x.shape[2], -1), + lambda x: (1, x, 1, 1), + ), + "xformers": ( + lambda x, head_dim: x.reshape(x.shape[0], x.shape[1], -1, head_dim), + lambda x: x.reshape(x.shape[0], x.shape[1], -1), + lambda x: (1, 1, x, 1), + ), + "math": ( + lambda x, head_dim: x.reshape(x.shape[0], x.shape[1], -1, head_dim).transpose(1, 2), + lambda x: x.transpose(1, 2).reshape(x.shape[0], x.shape[2], -1), + lambda x: (1, x, 1, 1), + ), +} +# ATTN_FUNCTION = { +# "torch": F.scaled_dot_product_attention, +# "xformers": memory_efficient_attention, +# } + + +def vanilla_attention(q, k, v, mask, scale=None): + if scale is None: + scale = math.sqrt(q.size(-1)) + scores = torch.bmm(q, k.transpose(-1, -2)) / scale + if mask is not None: + mask = einops.rearrange(mask, "b ... -> b (...)") + max_neg_value = -torch.finfo(scores.dtype).max + mask = einops.repeat(mask, "b j -> (b h) j", h=q.size(-3)) + scores = scores.masked_fill(~mask, max_neg_value) + p_attn = F.softmax(scores, dim=-1) + return torch.bmm(p_attn, v) + + +def attention(q, k, v, head_dim, mask=None, scale=None, mode="xformers"): + """ + q, k, v: [B, L, D] + """ + pre_attn_layout = MEMORY_LAYOUTS[mode][0] + post_attn_layout = MEMORY_LAYOUTS[mode][1] + q = pre_attn_layout(q, head_dim) + k = pre_attn_layout(k, head_dim) + v = pre_attn_layout(v, head_dim) + + # scores = ATTN_FUNCTION[mode](q, k.to(q), v.to(q), mask, scale=scale) + if mode == "torch": + assert scale is None + scores = F.scaled_dot_product_attention(q, k.to(q), v.to(q), mask) # , scale=scale) + elif mode == "xformers": + scores = memory_efficient_attention(q, k.to(q), v.to(q), mask, scale=scale) + else: + scores = vanilla_attention(q, k.to(q), v.to(q), mask, scale=scale) + + scores = post_attn_layout(scores) + return scores + + +# DismantledBlock in mmdit.py +class SingleDiTBlock(nn.Module): + """ + A DiT block with gated adaptive layer norm (adaLN) conditioning. + """ + + def __init__( + self, + hidden_size: int, + num_heads: int, + mlp_ratio: float = 4.0, + attn_mode: str = "xformers", + qkv_bias: bool = False, + pre_only: bool = False, + rmsnorm: bool = False, + scale_mod_only: bool = False, + swiglu: bool = False, + qk_norm: Optional[str] = None, + x_block_self_attn: bool = False, + **block_kwargs, + ): + super().__init__() + assert attn_mode in MEMORY_LAYOUTS + self.attn_mode = attn_mode + if not rmsnorm: + self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + else: + self.norm1 = RMSNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.attn = AttentionLinears(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, pre_only=pre_only, qk_norm=qk_norm) + + self.x_block_self_attn = x_block_self_attn + if self.x_block_self_attn: + assert not pre_only + assert not scale_mod_only + self.attn2 = AttentionLinears(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, pre_only=False, qk_norm=qk_norm) + + if not pre_only: + if not rmsnorm: + self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + else: + self.norm2 = RMSNorm(hidden_size, elementwise_affine=False, eps=1e-6) + mlp_hidden_dim = int(hidden_size * mlp_ratio) + if not pre_only: + if not swiglu: + self.mlp = MLP( + in_features=hidden_size, + hidden_features=mlp_hidden_dim, + act_layer=lambda: nn.GELU(approximate="tanh"), + ) + else: + self.mlp = SwiGLUFeedForward( + dim=hidden_size, + hidden_dim=mlp_hidden_dim, + multiple_of=256, + ) + self.scale_mod_only = scale_mod_only + if self.x_block_self_attn: + n_mods = 9 + elif not scale_mod_only: + n_mods = 6 if not pre_only else 2 + else: + n_mods = 4 if not pre_only else 1 + self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, n_mods * hidden_size)) + self.pre_only = pre_only + + def pre_attention(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor: + if not self.pre_only: + if not self.scale_mod_only: + (shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp) = self.adaLN_modulation(c).chunk(6, dim=-1) + else: + shift_msa = None + shift_mlp = None + (scale_msa, gate_msa, scale_mlp, gate_mlp) = self.adaLN_modulation(c).chunk(4, dim=-1) + qkv = self.attn.pre_attention(modulate(self.norm1(x), shift_msa, scale_msa)) + return qkv, (x, gate_msa, shift_mlp, scale_mlp, gate_mlp) + else: + if not self.scale_mod_only: + (shift_msa, scale_msa) = self.adaLN_modulation(c).chunk(2, dim=-1) + else: + shift_msa = None + scale_msa = self.adaLN_modulation(c) + qkv = self.attn.pre_attention(modulate(self.norm1(x), shift_msa, scale_msa)) + return qkv, None + + def pre_attention_x(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor: + assert self.x_block_self_attn + (shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp, shift_msa2, scale_msa2, gate_msa2) = self.adaLN_modulation( + c + ).chunk(9, dim=1) + x_norm = self.norm1(x) + qkv = self.attn.pre_attention(modulate(x_norm, shift_msa, scale_msa)) + qkv2 = self.attn2.pre_attention(modulate(x_norm, shift_msa2, scale_msa2)) + return qkv, qkv2, (x, gate_msa, shift_mlp, scale_mlp, gate_mlp, gate_msa2) + + def post_attention(self, attn, x, gate_msa, shift_mlp, scale_mlp, gate_mlp): + assert not self.pre_only + x = x + gate_msa.unsqueeze(1) * self.attn.post_attention(attn) + x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp)) + return x + + def post_attention_x(self, attn, attn2, x, gate_msa, shift_mlp, scale_mlp, gate_mlp, gate_msa2, attn1_dropout: float = 0.0): + assert not self.pre_only + if attn1_dropout > 0.0: + # Use torch.bernoulli to implement dropout, only dropout the batch dimension + attn1_dropout = torch.bernoulli(torch.full((attn.size(0), 1, 1), 1 - attn1_dropout, device=attn.device)) + attn_ = gate_msa.unsqueeze(1) * self.attn.post_attention(attn) * attn1_dropout + else: + attn_ = gate_msa.unsqueeze(1) * self.attn.post_attention(attn) + x = x + attn_ + attn2_ = gate_msa2.unsqueeze(1) * self.attn2.post_attention(attn2) + x = x + attn2_ + mlp_ = gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp)) + x = x + mlp_ + return x + + +# JointBlock + block_mixing in mmdit.py +class MMDiTBlock(nn.Module): + def __init__(self, *args, **kwargs): + super().__init__() + pre_only = kwargs.pop("pre_only") + x_block_self_attn = kwargs.pop("x_block_self_attn") + + self.context_block = SingleDiTBlock(*args, pre_only=pre_only, **kwargs) + self.x_block = SingleDiTBlock(*args, pre_only=False, x_block_self_attn=x_block_self_attn, **kwargs) + + self.head_dim = self.x_block.attn.head_dim + self.mode = self.x_block.attn_mode + self.gradient_checkpointing = False + + def enable_gradient_checkpointing(self): + self.gradient_checkpointing = True + + def _forward(self, context, x, c): + ctx_qkv, ctx_intermediate = self.context_block.pre_attention(context, c) + + if self.x_block.x_block_self_attn: + x_qkv, x_qkv2, x_intermediates = self.x_block.pre_attention_x(x, c) + else: + x_qkv, x_intermediates = self.x_block.pre_attention(x, c) + + ctx_len = ctx_qkv[0].size(1) + + q = torch.concat((ctx_qkv[0], x_qkv[0]), dim=1) + k = torch.concat((ctx_qkv[1], x_qkv[1]), dim=1) + v = torch.concat((ctx_qkv[2], x_qkv[2]), dim=1) + + attn = attention(q, k, v, head_dim=self.head_dim, mode=self.mode) + ctx_attn_out = attn[:, :ctx_len] + x_attn_out = attn[:, ctx_len:] + + if self.x_block.x_block_self_attn: + x_q2, x_k2, x_v2 = x_qkv2 + attn2 = attention(x_q2, x_k2, x_v2, self.x_block.attn2.num_heads, mode=self.mode) + x = self.x_block.post_attention_x(x_attn_out, attn2, *x_intermediates) + else: + x = self.x_block.post_attention(x_attn_out, *x_intermediates) + + if not self.context_block.pre_only: + context = self.context_block.post_attention(ctx_attn_out, *ctx_intermediate) + else: + context = None + + return context, x + + def forward(self, *args, **kwargs): + if self.training and self.gradient_checkpointing: + return checkpoint(self._forward, *args, use_reentrant=False, **kwargs) + else: + return self._forward(*args, **kwargs) + + +class MMDiT(nn.Module): + """ + Diffusion model with a Transformer backbone. + """ + + # prepare pos_embed for latent size * 2 + POS_EMBED_MAX_RATIO = 1.5 + + def __init__( + self, + input_size: int = 32, + patch_size: int = 2, + in_channels: int = 4, + depth: int = 28, + # hidden_size: Optional[int] = None, + # num_heads: Optional[int] = None, + mlp_ratio: float = 4.0, + learn_sigma: bool = False, + adm_in_channels: Optional[int] = None, + context_embedder_in_features: Optional[int] = None, + context_embedder_out_features: Optional[int] = None, + use_checkpoint: bool = False, + register_length: int = 0, + attn_mode: str = "torch", + rmsnorm: bool = False, + scale_mod_only: bool = False, + swiglu: bool = False, + out_channels: Optional[int] = None, + pos_embed_scaling_factor: Optional[float] = None, + pos_embed_offset: Optional[float] = None, + pos_embed_max_size: Optional[int] = None, + num_patches=None, + qk_norm: Optional[str] = None, + x_block_self_attn_layers: Optional[list[int]] = [], + qkv_bias: bool = True, + pos_emb_random_crop_rate: float = 0.0, + use_scaled_pos_embed: bool = False, + pos_embed_latent_sizes: Optional[list[int]] = None, + model_type: str = "sd3m", + ): + super().__init__() + self._model_type = model_type + self.learn_sigma = learn_sigma + self.in_channels = in_channels + default_out_channels = in_channels * 2 if learn_sigma else in_channels + self.out_channels = default(out_channels, default_out_channels) + self.patch_size = patch_size + self.pos_embed_scaling_factor = pos_embed_scaling_factor + self.pos_embed_offset = pos_embed_offset + self.pos_embed_max_size = pos_embed_max_size + self.x_block_self_attn_layers = x_block_self_attn_layers + self.pos_emb_random_crop_rate = pos_emb_random_crop_rate + self.gradient_checkpointing = use_checkpoint + + # hidden_size = default(hidden_size, 64 * depth) + # num_heads = default(num_heads, hidden_size // 64) + + # apply magic --> this defines a head_size of 64 + self.hidden_size = 64 * depth + num_heads = depth + + self.num_heads = num_heads + + self.enable_scaled_pos_embed(use_scaled_pos_embed, pos_embed_latent_sizes) + + self.x_embedder = PatchEmbed( + input_size, + patch_size, + in_channels, + self.hidden_size, + bias=True, + strict_img_size=self.pos_embed_max_size is None, + ) + self.t_embedder = TimestepEmbedding(self.hidden_size) + + self.y_embedder = None + if adm_in_channels is not None: + assert isinstance(adm_in_channels, int) + self.y_embedder = Embedder(adm_in_channels, self.hidden_size) + + if context_embedder_in_features is not None: + self.context_embedder = nn.Linear(context_embedder_in_features, context_embedder_out_features) + else: + self.context_embedder = nn.Identity() + + self.register_length = register_length + if self.register_length > 0: + self.register = nn.Parameter(torch.randn(1, register_length, self.hidden_size)) + + # num_patches = self.x_embedder.num_patches + # Will use fixed sin-cos embedding: + # just use a buffer already + if num_patches is not None: + self.register_buffer( + "pos_embed", + torch.empty(1, num_patches, self.hidden_size), + ) + else: + self.pos_embed = None + + self.use_checkpoint = use_checkpoint + self.joint_blocks = nn.ModuleList( + [ + MMDiTBlock( + self.hidden_size, + num_heads, + mlp_ratio=mlp_ratio, + attn_mode=attn_mode, + qkv_bias=qkv_bias, + pre_only=i == depth - 1, + rmsnorm=rmsnorm, + scale_mod_only=scale_mod_only, + swiglu=swiglu, + qk_norm=qk_norm, + x_block_self_attn=(i in self.x_block_self_attn_layers), + ) + for i in range(depth) + ] + ) + for block in self.joint_blocks: + block.gradient_checkpointing = use_checkpoint + + self.final_layer = UnPatch(self.hidden_size, patch_size, self.out_channels) + # self.initialize_weights() + + self.blocks_to_swap = None + self.offloader = None + self.num_blocks = len(self.joint_blocks) + + def enable_scaled_pos_embed(self, use_scaled_pos_embed: bool, latent_sizes: Optional[list[int]]): + self.use_scaled_pos_embed = use_scaled_pos_embed + + if self.use_scaled_pos_embed: + # # remove pos_embed to free up memory up to 0.4 GB -> this causes error because pos_embed is not saved + # self.pos_embed = None + # move pos_embed to CPU to free up memory up to 0.4 GB + self.pos_embed = self.pos_embed.cpu() + + # remove duplicates and sort latent sizes in ascending order + latent_sizes = list(set(latent_sizes)) + latent_sizes = sorted(latent_sizes) + + patched_sizes = [latent_size // self.patch_size for latent_size in latent_sizes] + + # calculate value range for each latent area: this is used to determine the pos_emb size from the latent shape + max_areas = [] + for i in range(1, len(patched_sizes)): + prev_area = patched_sizes[i - 1] ** 2 + area = patched_sizes[i] ** 2 + max_areas.append((prev_area + area) // 2) + + # area of the last latent size, if the latent size exceeds this, error will be raised + max_areas.append(int((patched_sizes[-1] * MMDiT.POS_EMBED_MAX_RATIO) ** 2)) + # print("max_areas", max_areas) + + self.resolution_area_to_latent_size = [(area, latent_size) for area, latent_size in zip(max_areas, patched_sizes)] + + self.resolution_pos_embeds = {} + for patched_size in patched_sizes: + grid_size = int(patched_size * MMDiT.POS_EMBED_MAX_RATIO) + pos_embed = get_scaled_2d_sincos_pos_embed(self.hidden_size, grid_size, sample_size=patched_size) + pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0) + self.resolution_pos_embeds[patched_size] = pos_embed + # print(f"pos_embed for {patched_size}x{patched_size} latent size: {pos_embed.shape}") + + else: + self.resolution_area_to_latent_size = None + self.resolution_pos_embeds = None + + @property + def model_type(self): + return self._model_type + + @property + def device(self): + return next(self.parameters()).device + + @property + def dtype(self): + return next(self.parameters()).dtype + + def enable_gradient_checkpointing(self): + self.gradient_checkpointing = True + for block in self.joint_blocks: + block.enable_gradient_checkpointing() + + def disable_gradient_checkpointing(self): + self.gradient_checkpointing = False + for block in self.joint_blocks: + block.disable_gradient_checkpointing() + + def initialize_weights(self): + # TODO: Init context_embedder? + # Initialize transformer layers: + def _basic_init(module): + if isinstance(module, nn.Linear): + torch.nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.constant_(module.bias, 0) + + self.apply(_basic_init) + + # Initialize (and freeze) pos_embed by sin-cos embedding + if self.pos_embed is not None: + pos_embed = get_2d_sincos_pos_embed( + self.pos_embed.shape[-1], + int(self.pos_embed.shape[-2] ** 0.5), + scaling_factor=self.pos_embed_scaling_factor, + ) + self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0)) + + # Initialize patch_embed like nn.Linear (instead of nn.Conv2d) + w = self.x_embedder.proj.weight.data + nn.init.xavier_uniform_(w.view([w.shape[0], -1])) + nn.init.constant_(self.x_embedder.proj.bias, 0) + + if getattr(self, "y_embedder", None) is not None: + nn.init.normal_(self.y_embedder.mlp[0].weight, std=0.02) + nn.init.normal_(self.y_embedder.mlp[2].weight, std=0.02) + + # Initialize timestep embedding MLP: + nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) + nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) + + # Zero-out adaLN modulation layers in DiT blocks: + for block in self.joint_blocks: + nn.init.constant_(block.x_block.adaLN_modulation[-1].weight, 0) + nn.init.constant_(block.x_block.adaLN_modulation[-1].bias, 0) + nn.init.constant_(block.context_block.adaLN_modulation[-1].weight, 0) + nn.init.constant_(block.context_block.adaLN_modulation[-1].bias, 0) + + # Zero-out output layers: + nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0) + nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0) + nn.init.constant_(self.final_layer.linear.weight, 0) + nn.init.constant_(self.final_layer.linear.bias, 0) + + def set_pos_emb_random_crop_rate(self, rate: float): + self.pos_emb_random_crop_rate = rate + + def cropped_pos_embed(self, h, w, device=None, random_crop: bool = False): + p = self.x_embedder.patch_size + # patched size + h = (h + 1) // p + w = (w + 1) // p + if self.pos_embed is None: # should not happen + return get_2d_sincos_pos_embed_torch(self.hidden_size, w, h, device=device) + assert self.pos_embed_max_size is not None + assert h <= self.pos_embed_max_size, (h, self.pos_embed_max_size) + assert w <= self.pos_embed_max_size, (w, self.pos_embed_max_size) + + if not random_crop: + top = (self.pos_embed_max_size - h) // 2 + left = (self.pos_embed_max_size - w) // 2 + else: + top = torch.randint(0, self.pos_embed_max_size - h + 1, (1,)).item() + left = torch.randint(0, self.pos_embed_max_size - w + 1, (1,)).item() + + spatial_pos_embed = self.pos_embed.reshape( + 1, + self.pos_embed_max_size, + self.pos_embed_max_size, + self.pos_embed.shape[-1], + ) + spatial_pos_embed = spatial_pos_embed[:, top : top + h, left : left + w, :] + spatial_pos_embed = spatial_pos_embed.reshape(1, -1, spatial_pos_embed.shape[-1]) + return spatial_pos_embed + + def cropped_scaled_pos_embed(self, h, w, device=None, dtype=None, random_crop: bool = False): + p = self.x_embedder.patch_size + # patched size + h = (h + 1) // p + w = (w + 1) // p + + # select pos_embed size based on area + area = h * w + patched_size = None + for area_, patched_size_ in self.resolution_area_to_latent_size: + if area <= area_: + patched_size = patched_size_ + break + if patched_size is None: + # raise ValueError(f"Area {area} is too large for the given latent sizes {self.resolution_area_to_latent_size}.") + # use largest latent size + patched_size = self.resolution_area_to_latent_size[-1][1] + + pos_embed = self.resolution_pos_embeds[patched_size] + pos_embed_size = round(math.sqrt(pos_embed.shape[1])) # max size, patched_size * POS_EMBED_MAX_RATIO + if h > pos_embed_size or w > pos_embed_size: + # # fallback to normal pos_embed + # return self.cropped_pos_embed(h * p, w * p, device=device, random_crop=random_crop) + # extend pos_embed size + logger.warning( + f"Add new pos_embed for size {h}x{w} as it exceeds the scaled pos_embed size {pos_embed_size}. Image is too tall or wide." + ) + patched_size = max(h, w) + grid_size = int(patched_size * MMDiT.POS_EMBED_MAX_RATIO) + pos_embed_size = grid_size + pos_embed = get_scaled_2d_sincos_pos_embed(self.hidden_size, grid_size, sample_size=patched_size) + pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0) + self.resolution_pos_embeds[patched_size] = pos_embed + logger.info(f"Added pos_embed for size {patched_size}x{patched_size}") + + # print(torch.allclose(pos_embed.to(torch.float32).cpu(), self.pos_embed.to(torch.float32).cpu(), atol=5e-2)) + # diff = pos_embed.to(torch.float32).cpu() - self.pos_embed.to(torch.float32).cpu() + # print(diff.abs().max(), diff.abs().mean()) + + # insert to resolution_area_to_latent_size, by adding and sorting + area = pos_embed_size**2 + self.resolution_area_to_latent_size.append((area, patched_size)) + self.resolution_area_to_latent_size = sorted(self.resolution_area_to_latent_size) + + if not random_crop: + top = (pos_embed_size - h) // 2 + left = (pos_embed_size - w) // 2 + else: + top = torch.randint(0, pos_embed_size - h + 1, (1,)).item() + left = torch.randint(0, pos_embed_size - w + 1, (1,)).item() + + if pos_embed.device != device: + pos_embed = pos_embed.to(device) + # which is better to update device, or transfer every time to device? -> 64x64 emb is 96*96*1536*4=56MB. It's okay to update device. + self.resolution_pos_embeds[patched_size] = pos_embed # update device + if pos_embed.dtype != dtype: + pos_embed = pos_embed.to(dtype) + self.resolution_pos_embeds[patched_size] = pos_embed # update dtype + + spatial_pos_embed = pos_embed.reshape(1, pos_embed_size, pos_embed_size, pos_embed.shape[-1]) + spatial_pos_embed = spatial_pos_embed[:, top : top + h, left : left + w, :] + spatial_pos_embed = spatial_pos_embed.reshape(1, -1, spatial_pos_embed.shape[-1]) + # print( + # f"patched size: {h}x{w}, pos_embed size: {pos_embed_size}, pos_embed shape: {pos_embed.shape}, top: {top}, left: {left}" + # ) + return spatial_pos_embed + + def enable_block_swap(self, num_blocks: int, device: torch.device): + self.blocks_to_swap = num_blocks + + assert ( + self.blocks_to_swap <= self.num_blocks - 2 + ), f"Cannot swap more than {self.num_blocks - 2} blocks. Requested: {self.blocks_to_swap} blocks." + + self.offloader = custom_offloading_utils.ModelOffloader( + self.joint_blocks, self.blocks_to_swap, device # , debug=True + ) + print(f"SD3: Block swap enabled. Swapping {num_blocks} blocks, total blocks: {self.num_blocks}, device: {device}.") + + def move_to_device_except_swap_blocks(self, device: torch.device): + # assume model is on cpu. do not move blocks to device to reduce temporary memory usage + if self.blocks_to_swap: + save_blocks = self.joint_blocks + self.joint_blocks = nn.ModuleList() + + self.to(device) + + if self.blocks_to_swap: + self.joint_blocks = save_blocks + + def prepare_block_swap_before_forward(self): + if self.blocks_to_swap is None or self.blocks_to_swap == 0: + return + self.offloader.prepare_block_devices_before_forward(self.joint_blocks) + + def forward( + self, + x: torch.Tensor, + t: torch.Tensor, + y: Optional[torch.Tensor] = None, + context: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """ + Forward pass of DiT. + x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images) + t: (N,) tensor of diffusion timesteps + y: (N, D) tensor of class labels + """ + pos_emb_random_crop = ( + False if self.pos_emb_random_crop_rate == 0.0 else torch.rand(1).item() < self.pos_emb_random_crop_rate + ) + + B, C, H, W = x.shape + + # x = self.x_embedder(x) + self.cropped_pos_embed(H, W, device=x.device, random_crop=pos_emb_random_crop).to(dtype=x.dtype) + if not self.use_scaled_pos_embed: + pos_embed = self.cropped_pos_embed(H, W, device=x.device, random_crop=pos_emb_random_crop).to(dtype=x.dtype) + else: + # print(f"Using scaled pos_embed for size {H}x{W}") + pos_embed = self.cropped_scaled_pos_embed(H, W, device=x.device, dtype=x.dtype, random_crop=pos_emb_random_crop) + x = self.x_embedder(x) + pos_embed + del pos_embed + + c = self.t_embedder(t, dtype=x.dtype) # (N, D) + if y is not None and self.y_embedder is not None: + y = self.y_embedder(y) # (N, D) + c = c + y # (N, D) + + if context is not None: + context = self.context_embedder(context) + + if self.register_length > 0: + context = torch.cat( + (einops.repeat(self.register, "1 ... -> b ...", b=x.shape[0]), default(context, torch.Tensor([]).type_as(x))), 1 + ) + + if not self.blocks_to_swap: + for block in self.joint_blocks: + context, x = block(context, x, c) + else: + for block_idx, block in enumerate(self.joint_blocks): + self.offloader.wait_for_block(block_idx) + + context, x = block(context, x, c) + + self.offloader.submit_move_blocks(self.joint_blocks, block_idx) + + x = self.final_layer(x, c, H, W) # Our final layer combined UnPatchify + return x[:, :, :H, :W] + + +def create_sd3_mmdit(params: SD3Params, attn_mode: str = "torch") -> MMDiT: + mmdit = MMDiT( + input_size=None, + pos_embed_max_size=params.pos_embed_max_size, + patch_size=params.patch_size, + in_channels=16, + adm_in_channels=params.adm_in_channels, + context_embedder_in_features=params.context_embedder_in_features, + context_embedder_out_features=params.context_embedder_out_features, + depth=params.depth, + mlp_ratio=4, + qk_norm=params.qk_norm, + x_block_self_attn_layers=params.x_block_self_attn_layers, + num_patches=params.num_patches, + attn_mode=attn_mode, + model_type=params.model_type, + ) + return mmdit + + +# endregion + +# region VAE + +VAE_SCALE_FACTOR = 1.5305 +VAE_SHIFT_FACTOR = 0.0609 + + +def Normalize(in_channels, num_groups=32, dtype=torch.float32, device=None): + return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True, dtype=dtype, device=device) + + +class ResnetBlock(torch.nn.Module): + def __init__(self, *, in_channels, out_channels=None, dtype=torch.float32, device=None): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + + self.norm1 = Normalize(in_channels, dtype=dtype, device=device) + self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, dtype=dtype, device=device) + self.norm2 = Normalize(out_channels, dtype=dtype, device=device) + self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, dtype=dtype, device=device) + if self.in_channels != self.out_channels: + self.nin_shortcut = torch.nn.Conv2d( + in_channels, out_channels, kernel_size=1, stride=1, padding=0, dtype=dtype, device=device + ) + else: + self.nin_shortcut = None + self.swish = torch.nn.SiLU(inplace=True) + + def forward(self, x): + hidden = x + hidden = self.norm1(hidden) + hidden = self.swish(hidden) + hidden = self.conv1(hidden) + hidden = self.norm2(hidden) + hidden = self.swish(hidden) + hidden = self.conv2(hidden) + if self.in_channels != self.out_channels: + x = self.nin_shortcut(x) + return x + hidden + + +class AttnBlock(torch.nn.Module): + def __init__(self, in_channels, dtype=torch.float32, device=None): + super().__init__() + self.norm = Normalize(in_channels, dtype=dtype, device=device) + self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0, dtype=dtype, device=device) + self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0, dtype=dtype, device=device) + self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0, dtype=dtype, device=device) + self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0, dtype=dtype, device=device) + + def forward(self, x): + hidden = self.norm(x) + q = self.q(hidden) + k = self.k(hidden) + v = self.v(hidden) + b, c, h, w = q.shape + q, k, v = map(lambda x: einops.rearrange(x, "b c h w -> b 1 (h w) c").contiguous(), (q, k, v)) + hidden = torch.nn.functional.scaled_dot_product_attention(q, k, v) # scale is dim ** -0.5 per default + hidden = einops.rearrange(hidden, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b) + hidden = self.proj_out(hidden) + return x + hidden + + +class Downsample(torch.nn.Module): + def __init__(self, in_channels, dtype=torch.float32, device=None): + super().__init__() + self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0, dtype=dtype, device=device) + + def forward(self, x): + pad = (0, 1, 0, 1) + x = torch.nn.functional.pad(x, pad, mode="constant", value=0) + x = self.conv(x) + return x + + +class Upsample(torch.nn.Module): + def __init__(self, in_channels, dtype=torch.float32, device=None): + super().__init__() + self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1, dtype=dtype, device=device) + + def forward(self, x): + org_dtype = x.dtype + if x.dtype == torch.bfloat16: + x = x.to(torch.float32) + x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") + if x.dtype != org_dtype: + x = x.to(org_dtype) + x = self.conv(x) + return x + + +class VAEEncoder(torch.nn.Module): + def __init__( + self, ch=128, ch_mult=(1, 2, 4, 4), num_res_blocks=2, in_channels=3, z_channels=16, dtype=torch.float32, device=None + ): + super().__init__() + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + # downsampling + self.conv_in = torch.nn.Conv2d(in_channels, ch, kernel_size=3, stride=1, padding=1, dtype=dtype, device=device) + in_ch_mult = (1,) + tuple(ch_mult) + self.in_ch_mult = in_ch_mult + self.down = torch.nn.ModuleList() + for i_level in range(self.num_resolutions): + block = torch.nn.ModuleList() + attn = torch.nn.ModuleList() + block_in = ch * in_ch_mult[i_level] + block_out = ch * ch_mult[i_level] + for i_block in range(num_res_blocks): + block.append(ResnetBlock(in_channels=block_in, out_channels=block_out, dtype=dtype, device=device)) + block_in = block_out + down = torch.nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions - 1: + down.downsample = Downsample(block_in, dtype=dtype, device=device) + self.down.append(down) + # middle + self.mid = torch.nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in, dtype=dtype, device=device) + self.mid.attn_1 = AttnBlock(block_in, dtype=dtype, device=device) + self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in, dtype=dtype, device=device) + # end + self.norm_out = Normalize(block_in, dtype=dtype, device=device) + self.conv_out = torch.nn.Conv2d(block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1, dtype=dtype, device=device) + self.swish = torch.nn.SiLU(inplace=True) + + def forward(self, x): + # downsampling + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1]) + hs.append(h) + if i_level != self.num_resolutions - 1: + hs.append(self.down[i_level].downsample(hs[-1])) + # middle + h = hs[-1] + h = self.mid.block_1(h) + h = self.mid.attn_1(h) + h = self.mid.block_2(h) + # end + h = self.norm_out(h) + h = self.swish(h) + h = self.conv_out(h) + return h + + +class VAEDecoder(torch.nn.Module): + def __init__( + self, + ch=128, + out_ch=3, + ch_mult=(1, 2, 4, 4), + num_res_blocks=2, + resolution=256, + z_channels=16, + dtype=torch.float32, + device=None, + ): + super().__init__() + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + block_in = ch * ch_mult[self.num_resolutions - 1] + curr_res = resolution // 2 ** (self.num_resolutions - 1) + # z to block_in + self.conv_in = torch.nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1, dtype=dtype, device=device) + # middle + self.mid = torch.nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in, dtype=dtype, device=device) + self.mid.attn_1 = AttnBlock(block_in, dtype=dtype, device=device) + self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in, dtype=dtype, device=device) + # upsampling + self.up = torch.nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = torch.nn.ModuleList() + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks + 1): + block.append(ResnetBlock(in_channels=block_in, out_channels=block_out, dtype=dtype, device=device)) + block_in = block_out + up = torch.nn.Module() + up.block = block + if i_level != 0: + up.upsample = Upsample(block_in, dtype=dtype, device=device) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + # end + self.norm_out = Normalize(block_in, dtype=dtype, device=device) + self.conv_out = torch.nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1, dtype=dtype, device=device) + self.swish = torch.nn.SiLU(inplace=True) + + def forward(self, z): + # z to block_in + hidden = self.conv_in(z) + # middle + hidden = self.mid.block_1(hidden) + hidden = self.mid.attn_1(hidden) + hidden = self.mid.block_2(hidden) + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + hidden = self.up[i_level].block[i_block](hidden) + if i_level != 0: + hidden = self.up[i_level].upsample(hidden) + # end + hidden = self.norm_out(hidden) + hidden = self.swish(hidden) + hidden = self.conv_out(hidden) + return hidden + + +class SDVAE(torch.nn.Module): + def __init__(self, dtype=torch.float32, device=None): + super().__init__() + self.encoder = VAEEncoder(dtype=dtype, device=device) + self.decoder = VAEDecoder(dtype=dtype, device=device) + + @property + def device(self): + return next(self.parameters()).device + + @property + def dtype(self): + return next(self.parameters()).dtype + + # @torch.autocast("cuda", dtype=torch.float16) + def decode(self, latent): + return self.decoder(latent) + + # @torch.autocast("cuda", dtype=torch.float16) + def encode(self, image): + hidden = self.encoder(image) + mean, logvar = torch.chunk(hidden, 2, dim=1) + logvar = torch.clamp(logvar, -30.0, 20.0) + std = torch.exp(0.5 * logvar) + return mean + std * torch.randn_like(mean) + + @staticmethod + def process_in(latent): + return (latent - VAE_SHIFT_FACTOR) * VAE_SCALE_FACTOR + + @staticmethod + def process_out(latent): + return (latent / VAE_SCALE_FACTOR) + VAE_SHIFT_FACTOR + + +# endregion diff --git a/library/sd3_train_utils.py b/library/sd3_train_utils.py new file mode 100644 index 000000000..c40798846 --- /dev/null +++ b/library/sd3_train_utils.py @@ -0,0 +1,945 @@ +import argparse +import math +import os +import toml +import json +import time +from typing import Dict, List, Optional, Tuple, Union + +import torch +from safetensors.torch import save_file +from accelerate import Accelerator, PartialState +from tqdm import tqdm +from PIL import Image +from transformers import CLIPTextModelWithProjection, T5EncoderModel + +from library.device_utils import init_ipex, clean_memory_on_device + +init_ipex() + +# from transformers import CLIPTokenizer +# from library import model_util +# , sdxl_model_util, train_util, sdxl_original_unet +# from library.sdxl_lpw_stable_diffusion import SdxlStableDiffusionLongPromptWeightingPipeline +from .utils import setup_logging + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + +from library import sd3_models, sd3_utils, strategy_base, train_util + + +def save_models( + ckpt_path: str, + mmdit: Optional[sd3_models.MMDiT], + vae: Optional[sd3_models.SDVAE], + clip_l: Optional[CLIPTextModelWithProjection], + clip_g: Optional[CLIPTextModelWithProjection], + t5xxl: Optional[T5EncoderModel], + sai_metadata: Optional[dict], + save_dtype: Optional[torch.dtype] = None, +): + r""" + Save models to checkpoint file. Only supports unified checkpoint format. + """ + + state_dict = {} + + def update_sd(prefix, sd): + for k, v in sd.items(): + key = prefix + k + if save_dtype is not None: + v = v.detach().clone().to("cpu").to(save_dtype) + state_dict[key] = v + + update_sd("model.diffusion_model.", mmdit.state_dict()) + update_sd("first_stage_model.", vae.state_dict()) + + # do not support unified checkpoint format for now + # if clip_l is not None: + # update_sd("text_encoders.clip_l.", clip_l.state_dict()) + # if clip_g is not None: + # update_sd("text_encoders.clip_g.", clip_g.state_dict()) + # if t5xxl is not None: + # update_sd("text_encoders.t5xxl.", t5xxl.state_dict()) + + save_file(state_dict, ckpt_path, metadata=sai_metadata) + + if clip_l is not None: + clip_l_path = ckpt_path.replace(".safetensors", "_clip_l.safetensors") + save_file(clip_l.state_dict(), clip_l_path) + if clip_g is not None: + clip_g_path = ckpt_path.replace(".safetensors", "_clip_g.safetensors") + save_file(clip_g.state_dict(), clip_g_path) + if t5xxl is not None: + t5xxl_path = ckpt_path.replace(".safetensors", "_t5xxl.safetensors") + t5xxl_state_dict = t5xxl.state_dict() + + # replace "shared.weight" with copy of it to avoid annoying shared tensor error on safetensors.save_file + shared_weight = t5xxl_state_dict["shared.weight"] + shared_weight_copy = shared_weight.detach().clone() + t5xxl_state_dict["shared.weight"] = shared_weight_copy + + save_file(t5xxl_state_dict, t5xxl_path) + + +def save_sd3_model_on_train_end( + args: argparse.Namespace, + save_dtype: torch.dtype, + epoch: int, + global_step: int, + clip_l: Optional[CLIPTextModelWithProjection], + clip_g: Optional[CLIPTextModelWithProjection], + t5xxl: Optional[T5EncoderModel], + mmdit: sd3_models.MMDiT, + vae: sd3_models.SDVAE, +): + def sd_saver(ckpt_file, epoch_no, global_step): + sai_metadata = train_util.get_sai_model_spec( + None, args, False, False, False, is_stable_diffusion_ckpt=True, sd3=mmdit.model_type + ) + save_models(ckpt_file, mmdit, vae, clip_l, clip_g, t5xxl, sai_metadata, save_dtype) + + train_util.save_sd_model_on_train_end_common(args, True, True, epoch, global_step, sd_saver, None) + + +# epochとstepの保存、メタデータにepoch/stepが含まれ引数が同じになるため、統合している +# on_epoch_end: Trueならepoch終了時、Falseならstep経過時 +def save_sd3_model_on_epoch_end_or_stepwise( + args: argparse.Namespace, + on_epoch_end: bool, + accelerator, + save_dtype: torch.dtype, + epoch: int, + num_train_epochs: int, + global_step: int, + clip_l: Optional[CLIPTextModelWithProjection], + clip_g: Optional[CLIPTextModelWithProjection], + t5xxl: Optional[T5EncoderModel], + mmdit: sd3_models.MMDiT, + vae: sd3_models.SDVAE, +): + def sd_saver(ckpt_file, epoch_no, global_step): + sai_metadata = train_util.get_sai_model_spec( + None, args, False, False, False, is_stable_diffusion_ckpt=True, sd3=mmdit.model_type + ) + save_models(ckpt_file, mmdit, vae, clip_l, clip_g, t5xxl, sai_metadata, save_dtype) + + train_util.save_sd_model_on_epoch_end_or_stepwise_common( + args, + on_epoch_end, + accelerator, + True, + True, + epoch, + num_train_epochs, + global_step, + sd_saver, + None, + ) + + +def add_sd3_training_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--clip_l", + type=str, + required=False, + help="CLIP-L model path. if not specified, use ckpt's state_dict / CLIP-Lモデルのパス。指定しない場合はckptのstate_dictを使用", + ) + parser.add_argument( + "--clip_g", + type=str, + required=False, + help="CLIP-G model path. if not specified, use ckpt's state_dict / CLIP-Gモデルのパス。指定しない場合はckptのstate_dictを使用", + ) + parser.add_argument( + "--t5xxl", + type=str, + required=False, + help="T5-XXL model path. if not specified, use ckpt's state_dict / T5-XXLモデルのパス。指定しない場合はckptのstate_dictを使用", + ) + parser.add_argument( + "--save_clip", + action="store_true", + help="[DOES NOT WORK] unified checkpoint is not supported / 統合チェックポイントはまだサポートされていません", + ) + parser.add_argument( + "--save_t5xxl", + action="store_true", + help="[DOES NOT WORK] unified checkpoint is not supported / 統合チェックポイントはまだサポートされていません", + ) + + parser.add_argument( + "--t5xxl_device", + type=str, + default=None, + help="[DOES NOT WORK] not supported yet. T5-XXL device. if not specified, use accelerator's device / T5-XXLデバイス。指定しない場合はacceleratorのデバイスを使用", + ) + parser.add_argument( + "--t5xxl_dtype", + type=str, + default=None, + help="[DOES NOT WORK] not supported yet. T5-XXL dtype. if not specified, use default dtype (from mixed precision) / T5-XXL dtype。指定しない場合はデフォルトのdtype(mixed precisionから)を使用", + ) + + parser.add_argument( + "--t5xxl_max_token_length", + type=int, + default=256, + help="maximum token length for T5-XXL. 256 is the default value / T5-XXLの最大トークン長。デフォルトは256", + ) + parser.add_argument( + "--apply_lg_attn_mask", + action="store_true", + help="apply attention mask (zero embs) to CLIP-L and G / CLIP-LとGにアテンションマスク(ゼロ埋め)を適用する", + ) + parser.add_argument( + "--apply_t5_attn_mask", + action="store_true", + help="apply attention mask (zero embs) to T5-XXL / T5-XXLにアテンションマスク(ゼロ埋め)を適用する", + ) + parser.add_argument( + "--clip_l_dropout_rate", + type=float, + default=0.0, + help="Dropout rate for CLIP-L encoder, default is 0.0 / CLIP-Lエンコーダのドロップアウト率、デフォルトは0.0", + ) + parser.add_argument( + "--clip_g_dropout_rate", + type=float, + default=0.0, + help="Dropout rate for CLIP-G encoder, default is 0.0 / CLIP-Gエンコーダのドロップアウト率、デフォルトは0.0", + ) + parser.add_argument( + "--t5_dropout_rate", + type=float, + default=0.0, + help="Dropout rate for T5 encoder, default is 0.0 / T5エンコーダのドロップアウト率、デフォルトは0.0", + ) + parser.add_argument( + "--pos_emb_random_crop_rate", + type=float, + default=0.0, + help="Random crop rate for positional embeddings, default is 0.0. Only for SD3.5M" + " / 位置埋め込みのランダムクロップ率、デフォルトは0.0。SD3.5M以外では予期しない動作になります", + ) + parser.add_argument( + "--enable_scaled_pos_embed", + action="store_true", + help="Scale position embeddings for each resolution during multi-resolution training. Only for SD3.5M" + " / 複数解像度学習時に解像度ごとに位置埋め込みをスケーリングする。SD3.5M以外では予期しない動作になります", + ) + + # Dependencies of Diffusers noise sampler has been removed for clarity in training + + parser.add_argument( + "--training_shift", + type=float, + default=1.0, + help="Discrete flow shift for training timestep distribution adjustment, applied in addition to the weighting scheme, default is 1.0. /タイムステップ分布のための離散フローシフト、重み付けスキームの上に適用される、デフォルトは1.0。", + ) + + +def verify_sdxl_training_args(args: argparse.Namespace, supportTextEncoderCaching: bool = True): + assert not args.v2, "v2 cannot be enabled in SDXL training / SDXL学習ではv2を有効にすることはできません" + if args.v_parameterization: + logger.warning("v_parameterization will be unexpected / SDXL学習ではv_parameterizationは想定外の動作になります") + + if args.clip_skip is not None: + logger.warning("clip_skip will be unexpected / SDXL学習ではclip_skipは動作しません") + + # if args.multires_noise_iterations: + # logger.info( + # f"Warning: SDXL has been trained with noise_offset={DEFAULT_NOISE_OFFSET}, but noise_offset is disabled due to multires_noise_iterations / SDXLはnoise_offset={DEFAULT_NOISE_OFFSET}で学習されていますが、multires_noise_iterationsが有効になっているためnoise_offsetは無効になります" + # ) + # else: + # if args.noise_offset is None: + # args.noise_offset = DEFAULT_NOISE_OFFSET + # elif args.noise_offset != DEFAULT_NOISE_OFFSET: + # logger.info( + # f"Warning: SDXL has been trained with noise_offset={DEFAULT_NOISE_OFFSET} / SDXLはnoise_offset={DEFAULT_NOISE_OFFSET}で学習されています" + # ) + # logger.info(f"noise_offset is set to {args.noise_offset} / noise_offsetが{args.noise_offset}に設定されました") + + assert ( + not hasattr(args, "weighted_captions") or not args.weighted_captions + ), "weighted_captions cannot be enabled in SDXL training currently / SDXL学習では今のところweighted_captionsを有効にすることはできません" + + if supportTextEncoderCaching: + if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs: + args.cache_text_encoder_outputs = True + logger.warning( + "cache_text_encoder_outputs is enabled because cache_text_encoder_outputs_to_disk is enabled / " + + "cache_text_encoder_outputs_to_diskが有効になっているためcache_text_encoder_outputsが有効になりました" + ) + + +# temporary copied from sd3_minimal_inferece.py + + +def get_all_sigmas(sampling: sd3_utils.ModelSamplingDiscreteFlow, steps): + start = sampling.timestep(sampling.sigma_max) + end = sampling.timestep(sampling.sigma_min) + timesteps = torch.linspace(start, end, steps) + sigs = [] + for x in range(len(timesteps)): + ts = timesteps[x] + sigs.append(sampling.sigma(ts)) + sigs += [0.0] + return torch.FloatTensor(sigs) + + +def max_denoise(model_sampling, sigmas): + max_sigma = float(model_sampling.sigma_max) + sigma = float(sigmas[0]) + return math.isclose(max_sigma, sigma, rel_tol=1e-05) or sigma > max_sigma + + +def do_sample( + height: int, + width: int, + seed: int, + cond: Tuple[torch.Tensor, torch.Tensor], + neg_cond: Tuple[torch.Tensor, torch.Tensor], + mmdit: sd3_models.MMDiT, + steps: int, + guidance_scale: float, + dtype: torch.dtype, + device: str, +): + latent = torch.zeros(1, 16, height // 8, width // 8, device=device) + latent = latent.to(dtype).to(device) + + # noise = get_noise(seed, latent).to(device) + if seed is not None: + generator = torch.manual_seed(seed) + else: + generator = None + noise = ( + torch.randn(latent.size(), dtype=torch.float32, layout=latent.layout, generator=generator, device="cpu") + .to(latent.dtype) + .to(device) + ) + + model_sampling = sd3_utils.ModelSamplingDiscreteFlow(shift=3.0) # 3.0 is for SD3 + + sigmas = get_all_sigmas(model_sampling, steps).to(device) + + noise_scaled = model_sampling.noise_scaling(sigmas[0], noise, latent, max_denoise(model_sampling, sigmas)) + + c_crossattn = torch.cat([cond[0], neg_cond[0]]).to(device).to(dtype) + y = torch.cat([cond[1], neg_cond[1]]).to(device).to(dtype) + + x = noise_scaled.to(device).to(dtype) + # print(x.shape) + + # with torch.no_grad(): + for i in tqdm(range(len(sigmas) - 1)): + sigma_hat = sigmas[i] + + timestep = model_sampling.timestep(sigma_hat).float() + timestep = torch.FloatTensor([timestep, timestep]).to(device) + + x_c_nc = torch.cat([x, x], dim=0) + # print(x_c_nc.shape, timestep.shape, c_crossattn.shape, y.shape) + + mmdit.prepare_block_swap_before_forward() + model_output = mmdit(x_c_nc, timestep, context=c_crossattn, y=y) + model_output = model_output.float() + batched = model_sampling.calculate_denoised(sigma_hat, model_output, x) + + pos_out, neg_out = batched.chunk(2) + denoised = neg_out + (pos_out - neg_out) * guidance_scale + # print(denoised.shape) + + # d = to_d(x, sigma_hat, denoised) + dims_to_append = x.ndim - sigma_hat.ndim + sigma_hat_dims = sigma_hat[(...,) + (None,) * dims_to_append] + # print(dims_to_append, x.shape, sigma_hat.shape, denoised.shape, sigma_hat_dims.shape) + """Converts a denoiser output to a Karras ODE derivative.""" + d = (x - denoised) / sigma_hat_dims + + dt = sigmas[i + 1] - sigma_hat + + # Euler method + x = x + d * dt + x = x.to(dtype) + + mmdit.prepare_block_swap_before_forward() + return x + + +def sample_images( + accelerator: Accelerator, + args: argparse.Namespace, + epoch, + steps, + mmdit, + vae, + text_encoders, + sample_prompts_te_outputs, + prompt_replacement=None, +): + if steps == 0: + if not args.sample_at_first: + return + else: + if args.sample_every_n_steps is None and args.sample_every_n_epochs is None: + return + if args.sample_every_n_epochs is not None: + # sample_every_n_steps は無視する + if epoch is None or epoch % args.sample_every_n_epochs != 0: + return + else: + if steps % args.sample_every_n_steps != 0 or epoch is not None: # steps is not divisible or end of epoch + return + + logger.info("") + logger.info(f"generating sample images at step / サンプル画像生成 ステップ: {steps}") + if not os.path.isfile(args.sample_prompts) and sample_prompts_te_outputs is None: + logger.error(f"No prompt file / プロンプトファイルがありません: {args.sample_prompts}") + return + + distributed_state = PartialState() # for multi gpu distributed inference. this is a singleton, so it's safe to use it here + + # unwrap unet and text_encoder(s) + mmdit = accelerator.unwrap_model(mmdit) + text_encoders = None if text_encoders is None else [accelerator.unwrap_model(te) for te in text_encoders] + # print([(te.parameters().__next__().device if te is not None else None) for te in text_encoders]) + + prompts = train_util.load_prompts(args.sample_prompts) + + save_dir = args.output_dir + "/sample" + os.makedirs(save_dir, exist_ok=True) + + # save random state to restore later + rng_state = torch.get_rng_state() + cuda_rng_state = None + try: + cuda_rng_state = torch.cuda.get_rng_state() if torch.cuda.is_available() else None + except Exception: + pass + + if distributed_state.num_processes <= 1: + # If only one device is available, just use the original prompt list. We don't need to care about the distribution of prompts. + with torch.no_grad(), accelerator.autocast(): + for prompt_dict in prompts: + sample_image_inference( + accelerator, + args, + mmdit, + text_encoders, + vae, + save_dir, + prompt_dict, + epoch, + steps, + sample_prompts_te_outputs, + prompt_replacement, + ) + else: + # Creating list with N elements, where each element is a list of prompt_dicts, and N is the number of processes available (number of devices available) + # prompt_dicts are assigned to lists based on order of processes, to attempt to time the image creation time to match enum order. Probably only works when steps and sampler are identical. + per_process_prompts = [] # list of lists + for i in range(distributed_state.num_processes): + per_process_prompts.append(prompts[i :: distributed_state.num_processes]) + + with torch.no_grad(): + with distributed_state.split_between_processes(per_process_prompts) as prompt_dict_lists: + for prompt_dict in prompt_dict_lists[0]: + sample_image_inference( + accelerator, + args, + mmdit, + text_encoders, + vae, + save_dir, + prompt_dict, + epoch, + steps, + sample_prompts_te_outputs, + prompt_replacement, + ) + + torch.set_rng_state(rng_state) + if cuda_rng_state is not None: + torch.cuda.set_rng_state(cuda_rng_state) + + clean_memory_on_device(accelerator.device) + + +def sample_image_inference( + accelerator: Accelerator, + args: argparse.Namespace, + mmdit: sd3_models.MMDiT, + text_encoders: List[Union[CLIPTextModelWithProjection, T5EncoderModel]], + vae: sd3_models.SDVAE, + save_dir, + prompt_dict, + epoch, + steps, + sample_prompts_te_outputs, + prompt_replacement, +): + assert isinstance(prompt_dict, dict) + negative_prompt = prompt_dict.get("negative_prompt") + sample_steps = prompt_dict.get("sample_steps", 30) + width = prompt_dict.get("width", 512) + height = prompt_dict.get("height", 512) + scale = prompt_dict.get("scale", 7.5) + seed = prompt_dict.get("seed") + # controlnet_image = prompt_dict.get("controlnet_image") + prompt: str = prompt_dict.get("prompt", "") + # sampler_name: str = prompt_dict.get("sample_sampler", args.sample_sampler) + + if prompt_replacement is not None: + prompt = prompt.replace(prompt_replacement[0], prompt_replacement[1]) + if negative_prompt is not None: + negative_prompt = negative_prompt.replace(prompt_replacement[0], prompt_replacement[1]) + + if seed is not None: + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + else: + # True random sample image generation + torch.seed() + torch.cuda.seed() + + if negative_prompt is None: + negative_prompt = "" + + height = max(64, height - height % 8) # round to divisible by 8 + width = max(64, width - width % 8) # round to divisible by 8 + logger.info(f"prompt: {prompt}") + logger.info(f"negative_prompt: {negative_prompt}") + logger.info(f"height: {height}") + logger.info(f"width: {width}") + logger.info(f"sample_steps: {sample_steps}") + logger.info(f"scale: {scale}") + # logger.info(f"sample_sampler: {sampler_name}") + if seed is not None: + logger.info(f"seed: {seed}") + + # encode prompts + tokenize_strategy = strategy_base.TokenizeStrategy.get_strategy() + encoding_strategy = strategy_base.TextEncodingStrategy.get_strategy() + + def encode_prompt(prpt): + text_encoder_conds = [] + if sample_prompts_te_outputs and prpt in sample_prompts_te_outputs: + text_encoder_conds = sample_prompts_te_outputs[prpt] + print(f"Using cached text encoder outputs for prompt: {prpt}") + if text_encoders is not None: + print(f"Encoding prompt: {prpt}") + tokens_and_masks = tokenize_strategy.tokenize(prpt) + # strategy has apply_t5_attn_mask option + encoded_text_encoder_conds = encoding_strategy.encode_tokens(tokenize_strategy, text_encoders, tokens_and_masks) + + # if text_encoder_conds is not cached, use encoded_text_encoder_conds + if len(text_encoder_conds) == 0: + text_encoder_conds = encoded_text_encoder_conds + else: + # if encoded_text_encoder_conds is not None, update cached text_encoder_conds + for i in range(len(encoded_text_encoder_conds)): + if encoded_text_encoder_conds[i] is not None: + text_encoder_conds[i] = encoded_text_encoder_conds[i] + return text_encoder_conds + + lg_out, t5_out, pooled, l_attn_mask, g_attn_mask, t5_attn_mask = encode_prompt(prompt) + cond = encoding_strategy.concat_encodings(lg_out, t5_out, pooled) + + # encode negative prompts + lg_out, t5_out, pooled, l_attn_mask, g_attn_mask, t5_attn_mask = encode_prompt(negative_prompt) + neg_cond = encoding_strategy.concat_encodings(lg_out, t5_out, pooled) + + # sample image + clean_memory_on_device(accelerator.device) + with accelerator.autocast(), torch.no_grad(): + # mmdit may be fp8, so we need weight_dtype here. vae is always in that dtype. + latents = do_sample(height, width, seed, cond, neg_cond, mmdit, sample_steps, scale, vae.dtype, accelerator.device) + + # latent to image + clean_memory_on_device(accelerator.device) + org_vae_device = vae.device # will be on cpu + vae.to(accelerator.device) + latents = vae.process_out(latents.to(vae.device, dtype=vae.dtype)) + image = vae.decode(latents) + vae.to(org_vae_device) + clean_memory_on_device(accelerator.device) + + image = image.float() + image = torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)[0] + decoded_np = 255.0 * np.moveaxis(image.cpu().numpy(), 0, 2) + decoded_np = decoded_np.astype(np.uint8) + + image = Image.fromarray(decoded_np) + # adding accelerator.wait_for_everyone() here should sync up and ensure that sample images are saved in the same order as the original prompt list + # but adding 'enum' to the filename should be enough + + ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime()) + num_suffix = f"e{epoch:06d}" if epoch is not None else f"{steps:06d}" + seed_suffix = "" if seed is None else f"_{seed}" + i: int = prompt_dict["enum"] + img_filename = f"{'' if args.output_name is None else args.output_name + '_'}{num_suffix}_{i:02d}_{ts_str}{seed_suffix}.png" + image.save(os.path.join(save_dir, img_filename)) + + # send images to wandb if enabled + if "wandb" in [tracker.name for tracker in accelerator.trackers]: + wandb_tracker = accelerator.get_tracker("wandb") + + import wandb + + # not to commit images to avoid inconsistency between training and logging steps + wandb_tracker.log({f"sample_{i}": wandb.Image(image, caption=prompt)}, commit=False) # positive prompt as a caption + + +# region Diffusers + + +from dataclasses import dataclass +from typing import Optional, Tuple, Union + +import numpy as np +import torch + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.schedulers.scheduling_utils import SchedulerMixin +from diffusers.utils.torch_utils import randn_tensor +from diffusers.utils import BaseOutput + + +@dataclass +class FlowMatchEulerDiscreteSchedulerOutput(BaseOutput): + """ + Output class for the scheduler's `step` function output. + + Args: + prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): + Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the + denoising loop. + """ + + prev_sample: torch.FloatTensor + + +class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin): + """ + Euler scheduler. + + This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic + methods the library implements for all schedulers such as loading and saving. + + Args: + num_train_timesteps (`int`, defaults to 1000): + The number of diffusion steps to train the model. + timestep_spacing (`str`, defaults to `"linspace"`): + The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and + Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information. + shift (`float`, defaults to 1.0): + The shift value for the timestep schedule. + """ + + _compatibles = [] + order = 1 + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + shift: float = 1.0, + ): + timesteps = np.linspace(1, num_train_timesteps, num_train_timesteps, dtype=np.float32)[::-1].copy() + timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32) + + sigmas = timesteps / num_train_timesteps + sigmas = shift * sigmas / (1 + (shift - 1) * sigmas) + + self.timesteps = sigmas * num_train_timesteps + + self._step_index = None + self._begin_index = None + + self.sigmas = sigmas.to("cpu") # to avoid too much CPU/GPU communication + self.sigma_min = self.sigmas[-1].item() + self.sigma_max = self.sigmas[0].item() + + @property + def step_index(self): + """ + The index counter for current timestep. It will increase 1 after each scheduler step. + """ + return self._step_index + + @property + def begin_index(self): + """ + The index for the first timestep. It should be set from pipeline with `set_begin_index` method. + """ + return self._begin_index + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index + def set_begin_index(self, begin_index: int = 0): + """ + Sets the begin index for the scheduler. This function should be run from pipeline before the inference. + + Args: + begin_index (`int`): + The begin index for the scheduler. + """ + self._begin_index = begin_index + + def scale_noise( + self, + sample: torch.FloatTensor, + timestep: Union[float, torch.FloatTensor], + noise: Optional[torch.FloatTensor] = None, + ) -> torch.FloatTensor: + """ + Forward process in flow-matching + + Args: + sample (`torch.FloatTensor`): + The input sample. + timestep (`int`, *optional*): + The current timestep in the diffusion chain. + + Returns: + `torch.FloatTensor`: + A scaled input sample. + """ + if self.step_index is None: + self._init_step_index(timestep) + + sigma = self.sigmas[self.step_index] + sample = sigma * noise + (1.0 - sigma) * sample + + return sample + + def _sigma_to_t(self, sigma): + return sigma * self.config.num_train_timesteps + + def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None): + """ + Sets the discrete timesteps used for the diffusion chain (to be run before inference). + + Args: + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + """ + self.num_inference_steps = num_inference_steps + + timesteps = np.linspace(self._sigma_to_t(self.sigma_max), self._sigma_to_t(self.sigma_min), num_inference_steps) + + sigmas = timesteps / self.config.num_train_timesteps + sigmas = self.config.shift * sigmas / (1 + (self.config.shift - 1) * sigmas) + sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device) + + timesteps = sigmas * self.config.num_train_timesteps + self.timesteps = timesteps.to(device=device) + self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)]) + + self._step_index = None + self._begin_index = None + + def index_for_timestep(self, timestep, schedule_timesteps=None): + if schedule_timesteps is None: + schedule_timesteps = self.timesteps + + indices = (schedule_timesteps == timestep).nonzero() + + # The sigma index that is taken for the **very** first `step` + # is always the second index (or the last index if there is only 1) + # This way we can ensure we don't accidentally skip a sigma in + # case we start in the middle of the denoising schedule (e.g. for image-to-image) + pos = 1 if len(indices) > 1 else 0 + + return indices[pos].item() + + def _init_step_index(self, timestep): + if self.begin_index is None: + if isinstance(timestep, torch.Tensor): + timestep = timestep.to(self.timesteps.device) + self._step_index = self.index_for_timestep(timestep) + else: + self._step_index = self._begin_index + + def step( + self, + model_output: torch.FloatTensor, + timestep: Union[float, torch.FloatTensor], + sample: torch.FloatTensor, + s_churn: float = 0.0, + s_tmin: float = 0.0, + s_tmax: float = float("inf"), + s_noise: float = 1.0, + generator: Optional[torch.Generator] = None, + return_dict: bool = True, + ) -> Union[FlowMatchEulerDiscreteSchedulerOutput, Tuple]: + """ + Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion + process from the learned model outputs (most often the predicted noise). + + Args: + model_output (`torch.FloatTensor`): + The direct output from learned diffusion model. + timestep (`float`): + The current discrete timestep in the diffusion chain. + sample (`torch.FloatTensor`): + A current instance of a sample created by the diffusion process. + s_churn (`float`): + s_tmin (`float`): + s_tmax (`float`): + s_noise (`float`, defaults to 1.0): + Scaling factor for noise added to the sample. + generator (`torch.Generator`, *optional*): + A random number generator. + return_dict (`bool`): + Whether or not to return a [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or + tuple. + + Returns: + [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or `tuple`: + If return_dict is `True`, [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] is + returned, otherwise a tuple is returned where the first element is the sample tensor. + """ + + if isinstance(timestep, int) or isinstance(timestep, torch.IntTensor) or isinstance(timestep, torch.LongTensor): + raise ValueError( + ( + "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to" + " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass" + " one of the `scheduler.timesteps` as a timestep." + ), + ) + + if self.step_index is None: + self._init_step_index(timestep) + + # Upcast to avoid precision issues when computing prev_sample + sample = sample.to(torch.float32) + + sigma = self.sigmas[self.step_index] + + gamma = min(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigma <= s_tmax else 0.0 + + noise = randn_tensor(model_output.shape, dtype=model_output.dtype, device=model_output.device, generator=generator) + + eps = noise * s_noise + sigma_hat = sigma * (gamma + 1) + + if gamma > 0: + sample = sample + eps * (sigma_hat**2 - sigma**2) ** 0.5 + + # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise + # NOTE: "original_sample" should not be an expected prediction_type but is left in for + # backwards compatibility + + # if self.config.prediction_type == "vector_field": + + denoised = sample - model_output * sigma + # 2. Convert to an ODE derivative + derivative = (sample - denoised) / sigma_hat + + dt = self.sigmas[self.step_index + 1] - sigma_hat + + prev_sample = sample + derivative * dt + # Cast sample back to model compatible dtype + prev_sample = prev_sample.to(model_output.dtype) + + # upon completion increase step index by one + self._step_index += 1 + + if not return_dict: + return (prev_sample,) + + return FlowMatchEulerDiscreteSchedulerOutput(prev_sample=prev_sample) + + def __len__(self): + return self.config.num_train_timesteps + + +def get_sigmas(noise_scheduler, timesteps, device, n_dim=4, dtype=torch.float32): + sigmas = noise_scheduler.sigmas.to(device=device, dtype=dtype) + schedule_timesteps = noise_scheduler.timesteps.to(device) + timesteps = timesteps.to(device) + step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] + + sigma = sigmas[step_indices].flatten() + while len(sigma.shape) < n_dim: + sigma = sigma.unsqueeze(-1) + return sigma + + +def compute_density_for_timestep_sampling( + weighting_scheme: str, batch_size: int, logit_mean: float = None, logit_std: float = None, mode_scale: float = None +): + """Compute the density for sampling the timesteps when doing SD3 training. + + Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528. + + SD3 paper reference: https://arxiv.org/abs/2403.03206v1. + """ + if weighting_scheme == "logit_normal": + # See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$). + u = torch.normal(mean=logit_mean, std=logit_std, size=(batch_size,), device="cpu") + u = torch.nn.functional.sigmoid(u) + elif weighting_scheme == "mode": + u = torch.rand(size=(batch_size,), device="cpu") + u = 1 - u - mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u) + else: + u = torch.rand(size=(batch_size,), device="cpu") + return u + + +def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None): + """Computes loss weighting scheme for SD3 training. + + Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528. + + SD3 paper reference: https://arxiv.org/abs/2403.03206v1. + """ + if weighting_scheme == "sigma_sqrt": + weighting = (sigmas**-2.0).float() + elif weighting_scheme == "cosmap": + bot = 1 - 2 * sigmas + 2 * sigmas**2 + weighting = 2 / (math.pi * bot) + else: + weighting = torch.ones_like(sigmas) + return weighting + + +# endregion + + +def get_noisy_model_input_and_timesteps(args, latents, noise, device, dtype) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + bsz = latents.shape[0] + + # Sample a random timestep for each image + # for weighting schemes where we sample timesteps non-uniformly + u = compute_density_for_timestep_sampling( + weighting_scheme=args.weighting_scheme, + batch_size=bsz, + logit_mean=args.logit_mean, + logit_std=args.logit_std, + mode_scale=args.mode_scale, + ) + t_min = args.min_timestep if args.min_timestep is not None else 0 + t_max = args.max_timestep if args.max_timestep is not None else 1000 + shift = args.training_shift + + # weighting shift, value >1 will shift distribution to noisy side (focus more on overall structure), value <1 will shift towards less-noisy side (focus more on details) + u = (u * shift) / (1 + (shift - 1) * u) + + indices = (u * (t_max - t_min) + t_min).long() + timesteps = indices.to(device=device, dtype=dtype) + + # sigmas according to flowmatching + sigmas = timesteps / 1000 + sigmas = sigmas.view(-1, 1, 1, 1) + noisy_model_input = sigmas * noise + (1.0 - sigmas) * latents + + return noisy_model_input, timesteps, sigmas diff --git a/library/sd3_utils.py b/library/sd3_utils.py new file mode 100644 index 000000000..5fbaa4c3e --- /dev/null +++ b/library/sd3_utils.py @@ -0,0 +1,302 @@ +from dataclasses import dataclass +import math +import re +from typing import Dict, List, Optional, Union +import torch +import safetensors +from safetensors.torch import load_file +from accelerate import init_empty_weights +from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPConfig, CLIPTextConfig + +from .utils import setup_logging + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + +from library import sd3_models + +# TODO move some of functions to model_util.py +from library import sdxl_model_util + +# region models + +# TODO remove dependency on flux_utils +from library.safetensors_utils import load_safetensors +from library.flux_utils import load_t5xxl as flux_utils_load_t5xxl + + +def analyze_state_dict_state(state_dict: Dict, prefix: str = ""): + logger.info(f"Analyzing state dict state...") + + # analyze configs + patch_size = state_dict[f"{prefix}x_embedder.proj.weight"].shape[2] + depth = state_dict[f"{prefix}x_embedder.proj.weight"].shape[0] // 64 + num_patches = state_dict[f"{prefix}pos_embed"].shape[1] + pos_embed_max_size = round(math.sqrt(num_patches)) + adm_in_channels = state_dict[f"{prefix}y_embedder.mlp.0.weight"].shape[1] + context_shape = state_dict[f"{prefix}context_embedder.weight"].shape + qk_norm = "rms" if f"{prefix}joint_blocks.0.context_block.attn.ln_k.weight" in state_dict.keys() else None + + # x_block_self_attn_layers.append(int(key.split(".x_block.attn2.ln_k.weight")[0].split(".")[-1])) + x_block_self_attn_layers = [] + re_attn = re.compile(r"\.(\d+)\.x_block\.attn2\.ln_k\.weight") + for key in list(state_dict.keys()): + m = re_attn.search(key) + if m: + x_block_self_attn_layers.append(int(m.group(1))) + + context_embedder_in_features = context_shape[1] + context_embedder_out_features = context_shape[0] + + # only supports 3-5-large, medium or 3-medium. This is added after `stable-diffusion-3-`. + if qk_norm is not None: + if len(x_block_self_attn_layers) == 0: + model_type = "5-large" + else: + model_type = "5-medium" + else: + model_type = "medium" + + params = sd3_models.SD3Params( + patch_size=patch_size, + depth=depth, + num_patches=num_patches, + pos_embed_max_size=pos_embed_max_size, + adm_in_channels=adm_in_channels, + qk_norm=qk_norm, + x_block_self_attn_layers=x_block_self_attn_layers, + context_embedder_in_features=context_embedder_in_features, + context_embedder_out_features=context_embedder_out_features, + model_type=model_type, + ) + logger.info(f"Analyzed state dict state: {params}") + return params + + +def load_mmdit( + state_dict: Dict, dtype: Optional[Union[str, torch.dtype]], device: Union[str, torch.device], attn_mode: str = "torch" +) -> sd3_models.MMDiT: + mmdit_sd = {} + + mmdit_prefix = "model.diffusion_model." + for k in list(state_dict.keys()): + if k.startswith(mmdit_prefix): + mmdit_sd[k[len(mmdit_prefix) :]] = state_dict.pop(k) + + # load MMDiT + logger.info("Building MMDit") + params = analyze_state_dict_state(mmdit_sd) + with init_empty_weights(): + mmdit = sd3_models.create_sd3_mmdit(params, attn_mode) + + logger.info("Loading state dict...") + info = mmdit.load_state_dict(mmdit_sd, strict=False, assign=True) + logger.info(f"Loaded MMDiT: {info}") + return mmdit + + +def load_clip_l( + clip_l_path: Optional[str], + dtype: Optional[Union[str, torch.dtype]], + device: Union[str, torch.device], + disable_mmap: bool = False, + state_dict: Optional[Dict] = None, +): + clip_l_sd = None + if clip_l_path is None: + if "text_encoders.clip_l.transformer.text_model.embeddings.position_embedding.weight" in state_dict: + # found clip_l: remove prefix "text_encoders.clip_l." + logger.info("clip_l is included in the checkpoint") + clip_l_sd = {} + prefix = "text_encoders.clip_l." + for k in list(state_dict.keys()): + if k.startswith(prefix): + clip_l_sd[k[len(prefix) :]] = state_dict.pop(k) + elif clip_l_path is None: + logger.info("clip_l is not included in the checkpoint and clip_l_path is not provided") + return None + + # load clip_l + logger.info("Building CLIP-L") + config = CLIPTextConfig( + vocab_size=49408, + hidden_size=768, + intermediate_size=3072, + num_hidden_layers=12, + num_attention_heads=12, + max_position_embeddings=77, + hidden_act="quick_gelu", + layer_norm_eps=1e-05, + dropout=0.0, + attention_dropout=0.0, + initializer_range=0.02, + initializer_factor=1.0, + pad_token_id=1, + bos_token_id=0, + eos_token_id=2, + model_type="clip_text_model", + projection_dim=768, + # torch_dtype="float32", + # transformers_version="4.25.0.dev0", + ) + with init_empty_weights(): + clip = CLIPTextModelWithProjection(config) + + if clip_l_sd is None: + logger.info(f"Loading state dict from {clip_l_path}") + clip_l_sd = load_safetensors(clip_l_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype) + + if "text_projection.weight" not in clip_l_sd: + logger.info("Adding text_projection.weight to clip_l_sd") + clip_l_sd["text_projection.weight"] = torch.eye(768, dtype=dtype, device=device) + + info = clip.load_state_dict(clip_l_sd, strict=False, assign=True) + logger.info(f"Loaded CLIP-L: {info}") + return clip + + +def load_clip_g( + clip_g_path: Optional[str], + dtype: Optional[Union[str, torch.dtype]], + device: Union[str, torch.device], + disable_mmap: bool = False, + state_dict: Optional[Dict] = None, +): + clip_g_sd = None + if state_dict is not None: + if "text_encoders.clip_g.transformer.text_model.embeddings.position_embedding.weight" in state_dict: + # found clip_g: remove prefix "text_encoders.clip_g." + logger.info("clip_g is included in the checkpoint") + clip_g_sd = {} + prefix = "text_encoders.clip_g." + for k in list(state_dict.keys()): + if k.startswith(prefix): + clip_g_sd[k[len(prefix) :]] = state_dict.pop(k) + elif clip_g_path is None: + logger.info("clip_g is not included in the checkpoint and clip_g_path is not provided") + return None + + # load clip_g + logger.info("Building CLIP-G") + config = CLIPTextConfig( + vocab_size=49408, + hidden_size=1280, + intermediate_size=5120, + num_hidden_layers=32, + num_attention_heads=20, + max_position_embeddings=77, + hidden_act="gelu", + layer_norm_eps=1e-05, + dropout=0.0, + attention_dropout=0.0, + initializer_range=0.02, + initializer_factor=1.0, + pad_token_id=1, + bos_token_id=0, + eos_token_id=2, + model_type="clip_text_model", + projection_dim=1280, + # torch_dtype="float32", + # transformers_version="4.25.0.dev0", + ) + with init_empty_weights(): + clip = CLIPTextModelWithProjection(config) + + if clip_g_sd is None: + logger.info(f"Loading state dict from {clip_g_path}") + clip_g_sd = load_safetensors(clip_g_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype) + info = clip.load_state_dict(clip_g_sd, strict=False, assign=True) + logger.info(f"Loaded CLIP-G: {info}") + return clip + + +def load_t5xxl( + t5xxl_path: Optional[str], + dtype: Optional[Union[str, torch.dtype]], + device: Union[str, torch.device], + disable_mmap: bool = False, + state_dict: Optional[Dict] = None, +): + t5xxl_sd = None + if state_dict is not None: + if "text_encoders.t5xxl.transformer.encoder.block.0.layer.0.SelfAttention.k.weight" in state_dict: + # found t5xxl: remove prefix "text_encoders.t5xxl." + logger.info("t5xxl is included in the checkpoint") + t5xxl_sd = {} + prefix = "text_encoders.t5xxl." + for k in list(state_dict.keys()): + if k.startswith(prefix): + t5xxl_sd[k[len(prefix) :]] = state_dict.pop(k) + elif t5xxl_path is None: + logger.info("t5xxl is not included in the checkpoint and t5xxl_path is not provided") + return None + + return flux_utils_load_t5xxl(t5xxl_path, dtype, device, disable_mmap, state_dict=t5xxl_sd) + + +def load_vae( + vae_path: Optional[str], + vae_dtype: Optional[Union[str, torch.dtype]], + device: Optional[Union[str, torch.device]], + disable_mmap: bool = False, + state_dict: Optional[Dict] = None, +): + vae_sd = {} + if vae_path: + logger.info(f"Loading VAE from {vae_path}...") + vae_sd = load_safetensors(vae_path, device, disable_mmap, dtype=vae_dtype) + else: + # remove prefix "first_stage_model." + vae_sd = {} + vae_prefix = "first_stage_model." + for k in list(state_dict.keys()): + if k.startswith(vae_prefix): + vae_sd[k[len(vae_prefix) :]] = state_dict.pop(k) + + logger.info("Building VAE") + vae = sd3_models.SDVAE(vae_dtype, device) + logger.info("Loading state dict...") + info = vae.load_state_dict(vae_sd) + logger.info(f"Loaded VAE: {info}") + vae.to(device=device, dtype=vae_dtype) # make sure it's in the right device and dtype + return vae + + +# endregion + + +class ModelSamplingDiscreteFlow: + """Helper for sampler scheduling (ie timestep/sigma calculations) for Discrete Flow models""" + + def __init__(self, shift=1.0): + self.shift = shift + timesteps = 1000 + self.sigmas = self.sigma(torch.arange(1, timesteps + 1, 1)) + + @property + def sigma_min(self): + return self.sigmas[0] + + @property + def sigma_max(self): + return self.sigmas[-1] + + def timestep(self, sigma): + return sigma * 1000 + + def sigma(self, timestep: torch.Tensor): + timestep = timestep / 1000.0 + if self.shift == 1.0: + return timestep + return self.shift * timestep / (1 + (self.shift - 1) * timestep) + + def calculate_denoised(self, sigma, model_output, model_input): + sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1)) + return model_input - model_output * sigma + + def noise_scaling(self, sigma, noise, latent_image, max_denoise=False): + # assert max_denoise is False, "max_denoise not implemented" + # max_denoise is always True, I'm not sure why it's there + return sigma * noise + (1.0 - sigma) * latent_image diff --git a/library/sdxl_lpw_stable_diffusion.py b/library/sdxl_lpw_stable_diffusion.py index 03b182566..9196eb0f2 100644 --- a/library/sdxl_lpw_stable_diffusion.py +++ b/library/sdxl_lpw_stable_diffusion.py @@ -13,12 +13,20 @@ from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer from diffusers import SchedulerMixin, StableDiffusionPipeline -from diffusers.models import AutoencoderKL, UNet2DConditionModel -from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker +from diffusers.models import AutoencoderKL +from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker from diffusers.utils import logging from PIL import Image -from library import sdxl_model_util, sdxl_train_util, train_util +from library import ( + sdxl_model_util, + sdxl_train_util, + strategy_base, + strategy_sdxl, + train_util, + sdxl_original_unet, + sdxl_original_control_net, +) try: @@ -537,7 +545,7 @@ def __init__( vae: AutoencoderKL, text_encoder: List[CLIPTextModel], tokenizer: List[CLIPTokenizer], - unet: UNet2DConditionModel, + unet: Union[sdxl_original_unet.SdxlUNet2DConditionModel, sdxl_original_control_net.SdxlControlledUNet], scheduler: SchedulerMixin, # clip_skip: int, safety_checker: StableDiffusionSafetyChecker, @@ -594,74 +602,6 @@ def _execution_device(self): return torch.device(module._hf_hook.execution_device) return self.device - def _encode_prompt( - self, - prompt, - device, - num_images_per_prompt, - do_classifier_free_guidance, - negative_prompt, - max_embeddings_multiples, - is_sdxl_text_encoder2, - ): - r""" - Encodes the prompt into text encoder hidden states. - - Args: - prompt (`str` or `list(int)`): - prompt to be encoded - device: (`torch.device`): - torch device - num_images_per_prompt (`int`): - number of images that should be generated per prompt - do_classifier_free_guidance (`bool`): - whether to use classifier free guidance or not - negative_prompt (`str` or `List[str]`): - The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored - if `guidance_scale` is less than `1`). - max_embeddings_multiples (`int`, *optional*, defaults to `3`): - The max multiple length of prompt embeddings compared to the max output length of text encoder. - """ - batch_size = len(prompt) if isinstance(prompt, list) else 1 - - if negative_prompt is None: - negative_prompt = [""] * batch_size - elif isinstance(negative_prompt, str): - negative_prompt = [negative_prompt] * batch_size - if batch_size != len(negative_prompt): - raise ValueError( - f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" - f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" - " the batch size of `prompt`." - ) - - text_embeddings, text_pool, uncond_embeddings, uncond_pool = get_weighted_text_embeddings( - pipe=self, - prompt=prompt, - uncond_prompt=negative_prompt if do_classifier_free_guidance else None, - max_embeddings_multiples=max_embeddings_multiples, - clip_skip=self.clip_skip, - is_sdxl_text_encoder2=is_sdxl_text_encoder2, - ) - bs_embed, seq_len, _ = text_embeddings.shape - text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1) # ?? - text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1) - if text_pool is not None: - text_pool = text_pool.repeat(1, num_images_per_prompt) - text_pool = text_pool.view(bs_embed * num_images_per_prompt, -1) - - if do_classifier_free_guidance: - bs_embed, seq_len, _ = uncond_embeddings.shape - uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1) - uncond_embeddings = uncond_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1) - if uncond_pool is not None: - uncond_pool = uncond_pool.repeat(1, num_images_per_prompt) - uncond_pool = uncond_pool.view(bs_embed * num_images_per_prompt, -1) - - return text_embeddings, text_pool, uncond_embeddings, uncond_pool - - return text_embeddings, text_pool, None, None - def check_inputs(self, prompt, height, width, strength, callback_steps): if not isinstance(prompt, str) and not isinstance(prompt, list): raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") @@ -792,7 +732,7 @@ def __call__( max_embeddings_multiples: Optional[int] = 3, output_type: Optional[str] = "pil", return_dict: bool = True, - controlnet=None, + controlnet: sdxl_original_control_net.SdxlControlNet = None, controlnet_image=None, callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, is_cancelled_callback: Optional[Callable[[], bool]] = None, @@ -896,32 +836,24 @@ def __call__( do_classifier_free_guidance = guidance_scale > 1.0 # 3. Encode input prompt - # 実装を簡単にするためにtokenzer/text encoderを切り替えて二回呼び出す - # To simplify the implementation, switch the tokenzer/text encoder and call it twice - text_embeddings_list = [] - text_pool = None - uncond_embeddings_list = [] - uncond_pool = None - for i in range(len(self.tokenizers)): - self.tokenizer = self.tokenizers[i] - self.text_encoder = self.text_encoders[i] - - text_embeddings, tp1, uncond_embeddings, up1 = self._encode_prompt( - prompt, - device, - num_images_per_prompt, - do_classifier_free_guidance, - negative_prompt, - max_embeddings_multiples, - is_sdxl_text_encoder2=i == 1, - ) - text_embeddings_list.append(text_embeddings) - uncond_embeddings_list.append(uncond_embeddings) + tokenize_strategy: strategy_sdxl.SdxlTokenizeStrategy = strategy_base.TokenizeStrategy.get_strategy() + encoding_strategy: strategy_sdxl.SdxlTextEncodingStrategy = strategy_base.TextEncodingStrategy.get_strategy() - if tp1 is not None: - text_pool = tp1 - if up1 is not None: - uncond_pool = up1 + text_input_ids, text_weights = tokenize_strategy.tokenize_with_weights(prompt) + hidden_states_1, hidden_states_2, text_pool = encoding_strategy.encode_tokens_with_weights( + tokenize_strategy, self.text_encoders, text_input_ids, text_weights + ) + text_embeddings = torch.cat([hidden_states_1, hidden_states_2], dim=-1) + + if do_classifier_free_guidance: + input_ids, weights = tokenize_strategy.tokenize_with_weights(negative_prompt or "") + hidden_states_1, hidden_states_2, uncond_pool = encoding_strategy.encode_tokens_with_weights( + tokenize_strategy, self.text_encoders, input_ids, weights + ) + uncond_embeddings = torch.cat([hidden_states_1, hidden_states_2], dim=-1) + else: + uncond_embeddings = None + uncond_pool = None unet_dtype = self.unet.dtype dtype = unet_dtype @@ -970,23 +902,23 @@ def __call__( extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) # create size embs and concat embeddings for SDXL - orig_size = torch.tensor([height, width]).repeat(batch_size * num_images_per_prompt, 1).to(dtype) + orig_size = torch.tensor([height, width]).repeat(batch_size * num_images_per_prompt, 1).to(device, dtype) crop_size = torch.zeros_like(orig_size) target_size = orig_size - embs = sdxl_train_util.get_size_embeddings(orig_size, crop_size, target_size, device).to(dtype) + embs = sdxl_train_util.get_size_embeddings(orig_size, crop_size, target_size, device).to(device, dtype) # make conditionings + text_pool = text_pool.to(device, dtype) if do_classifier_free_guidance: - text_embeddings = torch.cat(text_embeddings_list, dim=2) - uncond_embeddings = torch.cat(uncond_embeddings_list, dim=2) - text_embedding = torch.cat([uncond_embeddings, text_embeddings]).to(dtype) + text_embedding = torch.cat([uncond_embeddings, text_embeddings]).to(device, dtype) - cond_vector = torch.cat([text_pool, embs], dim=1) - uncond_vector = torch.cat([uncond_pool, embs], dim=1) - vector_embedding = torch.cat([uncond_vector, cond_vector]).to(dtype) + uncond_pool = uncond_pool.to(device, dtype) + cond_vector = torch.cat([text_pool, embs], dim=1).to(dtype) + uncond_vector = torch.cat([uncond_pool, embs], dim=1).to(dtype) + vector_embedding = torch.cat([uncond_vector, cond_vector]) else: - text_embedding = torch.cat(text_embeddings_list, dim=2).to(dtype) - vector_embedding = torch.cat([text_pool, embs], dim=1).to(dtype) + text_embedding = text_embeddings.to(device, dtype) + vector_embedding = torch.cat([text_pool, embs], dim=1) # 8. Denoising loop for i, t in enumerate(self.progress_bar(timesteps)): @@ -994,22 +926,14 @@ def __call__( latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) - unet_additional_args = {} - if controlnet is not None: - down_block_res_samples, mid_block_res_sample = controlnet( - latent_model_input, - t, - encoder_hidden_states=text_embeddings, - controlnet_cond=controlnet_image, - conditioning_scale=1.0, - guess_mode=False, - return_dict=False, - ) - unet_additional_args["down_block_additional_residuals"] = down_block_res_samples - unet_additional_args["mid_block_additional_residual"] = mid_block_res_sample + # FIXME SD1 ControlNet is not working # predict the noise residual - noise_pred = self.unet(latent_model_input, t, text_embedding, vector_embedding) + if controlnet is not None: + input_resi_add, mid_add = controlnet(latent_model_input, t, text_embedding, vector_embedding, controlnet_image) + noise_pred = self.unet(latent_model_input, t, text_embedding, vector_embedding, input_resi_add, mid_add) + else: + noise_pred = self.unet(latent_model_input, t, text_embedding, vector_embedding) noise_pred = noise_pred.to(dtype) # U-Net changes dtype in LoRA training # perform guidance diff --git a/library/sdxl_model_util.py b/library/sdxl_model_util.py index 4fad78a1c..0466c1fa5 100644 --- a/library/sdxl_model_util.py +++ b/library/sdxl_model_util.py @@ -8,7 +8,7 @@ from diffusers import AutoencoderKL, EulerDiscreteScheduler, UNet2DConditionModel from library import model_util from library import sdxl_original_unet -from .utils import setup_logging +from library.utils import setup_logging setup_logging() import logging diff --git a/library/sdxl_original_control_net.py b/library/sdxl_original_control_net.py new file mode 100644 index 000000000..3af45f4db --- /dev/null +++ b/library/sdxl_original_control_net.py @@ -0,0 +1,272 @@ +# some parts are modified from Diffusers library (Apache License 2.0) + +import math +from types import SimpleNamespace +from typing import Any, Optional +import torch +import torch.utils.checkpoint +from torch import nn +from torch.nn import functional as F +from einops import rearrange +from library.utils import setup_logging + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + +from library import sdxl_original_unet +from library.sdxl_model_util import convert_sdxl_unet_state_dict_to_diffusers, convert_diffusers_unet_state_dict_to_sdxl + + +class ControlNetConditioningEmbedding(nn.Module): + def __init__(self): + super().__init__() + + dims = [16, 32, 96, 256] + + self.conv_in = nn.Conv2d(3, dims[0], kernel_size=3, padding=1) + self.blocks = nn.ModuleList([]) + + for i in range(len(dims) - 1): + channel_in = dims[i] + channel_out = dims[i + 1] + self.blocks.append(nn.Conv2d(channel_in, channel_in, kernel_size=3, padding=1)) + self.blocks.append(nn.Conv2d(channel_in, channel_out, kernel_size=3, padding=1, stride=2)) + + self.conv_out = nn.Conv2d(dims[-1], 320, kernel_size=3, padding=1) + nn.init.zeros_(self.conv_out.weight) # zero module weight + nn.init.zeros_(self.conv_out.bias) # zero module bias + + def forward(self, x): + x = self.conv_in(x) + x = F.silu(x) + for block in self.blocks: + x = block(x) + x = F.silu(x) + x = self.conv_out(x) + return x + + +class SdxlControlNet(sdxl_original_unet.SdxlUNet2DConditionModel): + def __init__(self, multiplier: Optional[float] = None, **kwargs): + super().__init__(**kwargs) + self.multiplier = multiplier + + # remove unet layers + self.output_blocks = nn.ModuleList([]) + del self.out + + self.controlnet_cond_embedding = ControlNetConditioningEmbedding() + + dims = [320, 320, 320, 320, 640, 640, 640, 1280, 1280] + self.controlnet_down_blocks = nn.ModuleList([]) + for dim in dims: + self.controlnet_down_blocks.append(nn.Conv2d(dim, dim, kernel_size=1)) + nn.init.zeros_(self.controlnet_down_blocks[-1].weight) # zero module weight + nn.init.zeros_(self.controlnet_down_blocks[-1].bias) # zero module bias + + self.controlnet_mid_block = nn.Conv2d(1280, 1280, kernel_size=1) + nn.init.zeros_(self.controlnet_mid_block.weight) # zero module weight + nn.init.zeros_(self.controlnet_mid_block.bias) # zero module bias + + def init_from_unet(self, unet: sdxl_original_unet.SdxlUNet2DConditionModel): + unet_sd = unet.state_dict() + unet_sd = {k: v for k, v in unet_sd.items() if not k.startswith("out")} + sd = super().state_dict() + sd.update(unet_sd) + info = super().load_state_dict(sd, strict=True, assign=True) + return info + + def load_state_dict(self, state_dict: dict, strict: bool = True, assign: bool = True) -> Any: + # convert state_dict to SAI format + unet_sd = {} + for k in list(state_dict.keys()): + if not k.startswith("controlnet_"): + unet_sd[k] = state_dict.pop(k) + unet_sd = convert_diffusers_unet_state_dict_to_sdxl(unet_sd) + state_dict.update(unet_sd) + super().load_state_dict(state_dict, strict=strict, assign=assign) + + def state_dict(self, destination=None, prefix="", keep_vars=False): + # convert state_dict to Diffusers format + state_dict = super().state_dict(destination, prefix, keep_vars) + control_net_sd = {} + for k in list(state_dict.keys()): + if k.startswith("controlnet_"): + control_net_sd[k] = state_dict.pop(k) + state_dict = convert_sdxl_unet_state_dict_to_diffusers(state_dict) + state_dict.update(control_net_sd) + return state_dict + + def forward( + self, + x: torch.Tensor, + timesteps: Optional[torch.Tensor] = None, + context: Optional[torch.Tensor] = None, + y: Optional[torch.Tensor] = None, + cond_image: Optional[torch.Tensor] = None, + **kwargs, + ) -> torch.Tensor: + # broadcast timesteps to batch dimension + timesteps = timesteps.expand(x.shape[0]) + + t_emb = sdxl_original_unet.get_timestep_embedding(timesteps, self.model_channels, downscale_freq_shift=0) + t_emb = t_emb.to(x.dtype) + emb = self.time_embed(t_emb) + + assert x.shape[0] == y.shape[0], f"batch size mismatch: {x.shape[0]} != {y.shape[0]}" + assert x.dtype == y.dtype, f"dtype mismatch: {x.dtype} != {y.dtype}" + emb = emb + self.label_emb(y) + + def call_module(module, h, emb, context): + x = h + for layer in module: + if isinstance(layer, sdxl_original_unet.ResnetBlock2D): + x = layer(x, emb) + elif isinstance(layer, sdxl_original_unet.Transformer2DModel): + x = layer(x, context) + else: + x = layer(x) + return x + + h = x + multiplier = self.multiplier if self.multiplier is not None else 1.0 + hs = [] + for i, module in enumerate(self.input_blocks): + h = call_module(module, h, emb, context) + if i == 0: + h = self.controlnet_cond_embedding(cond_image) + h + hs.append(self.controlnet_down_blocks[i](h) * multiplier) + + h = call_module(self.middle_block, h, emb, context) + h = self.controlnet_mid_block(h) * multiplier + + return hs, h + + +class SdxlControlledUNet(sdxl_original_unet.SdxlUNet2DConditionModel): + """ + This class is for training purpose only. + """ + + def __init__(self, **kwargs): + super().__init__(**kwargs) + + def forward(self, x, timesteps=None, context=None, y=None, input_resi_add=None, mid_add=None, **kwargs): + # broadcast timesteps to batch dimension + timesteps = timesteps.expand(x.shape[0]) + + hs = [] + t_emb = sdxl_original_unet.get_timestep_embedding(timesteps, self.model_channels, downscale_freq_shift=0) + t_emb = t_emb.to(x.dtype) + emb = self.time_embed(t_emb) + + assert x.shape[0] == y.shape[0], f"batch size mismatch: {x.shape[0]} != {y.shape[0]}" + assert x.dtype == y.dtype, f"dtype mismatch: {x.dtype} != {y.dtype}" + emb = emb + self.label_emb(y) + + def call_module(module, h, emb, context): + x = h + for layer in module: + if isinstance(layer, sdxl_original_unet.ResnetBlock2D): + x = layer(x, emb) + elif isinstance(layer, sdxl_original_unet.Transformer2DModel): + x = layer(x, context) + else: + x = layer(x) + return x + + h = x + for module in self.input_blocks: + h = call_module(module, h, emb, context) + hs.append(h) + + h = call_module(self.middle_block, h, emb, context) + h = h + mid_add + + for module in self.output_blocks: + resi = hs.pop() + input_resi_add.pop() + h = torch.cat([h, resi], dim=1) + h = call_module(module, h, emb, context) + + h = h.type(x.dtype) + h = call_module(self.out, h, emb, context) + + return h + + +if __name__ == "__main__": + import time + + logger.info("create unet") + unet = SdxlControlledUNet() + unet.to("cuda", torch.bfloat16) + unet.set_use_sdpa(True) + unet.set_gradient_checkpointing(True) + unet.train() + + logger.info("create control_net") + control_net = SdxlControlNet() + control_net.to("cuda") + control_net.set_use_sdpa(True) + control_net.set_gradient_checkpointing(True) + control_net.train() + + logger.info("Initialize control_net from unet") + control_net.init_from_unet(unet) + + unet.requires_grad_(False) + control_net.requires_grad_(True) + + # 使用メモリ量確認用の疑似学習ループ + logger.info("preparing optimizer") + + # optimizer = torch.optim.SGD(unet.parameters(), lr=1e-3, nesterov=True, momentum=0.9) # not working + + import bitsandbytes + + optimizer = bitsandbytes.adam.Adam8bit(control_net.parameters(), lr=1e-3) # not working + # optimizer = bitsandbytes.optim.RMSprop8bit(unet.parameters(), lr=1e-3) # working at 23.5 GB with torch2 + # optimizer=bitsandbytes.optim.Adagrad8bit(unet.parameters(), lr=1e-3) # working at 23.5 GB with torch2 + + # import transformers + # optimizer = transformers.optimization.Adafactor(unet.parameters(), relative_step=True) # working at 22.2GB with torch2 + + scaler = torch.cuda.amp.GradScaler(enabled=True) + + logger.info("start training") + steps = 10 + batch_size = 1 + + for step in range(steps): + logger.info(f"step {step}") + if step == 1: + time_start = time.perf_counter() + + x = torch.randn(batch_size, 4, 128, 128).cuda() # 1024x1024 + t = torch.randint(low=0, high=1000, size=(batch_size,), device="cuda") + txt = torch.randn(batch_size, 77, 2048).cuda() + vector = torch.randn(batch_size, sdxl_original_unet.ADM_IN_CHANNELS).cuda() + cond_img = torch.rand(batch_size, 3, 1024, 1024).cuda() + + with torch.cuda.amp.autocast(enabled=True, dtype=torch.bfloat16): + input_resi_add, mid_add = control_net(x, t, txt, vector, cond_img) + output = unet(x, t, txt, vector, input_resi_add, mid_add) + target = torch.randn_like(output) + loss = torch.nn.functional.mse_loss(output, target) + + scaler.scale(loss).backward() + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad(set_to_none=True) + + time_end = time.perf_counter() + logger.info(f"elapsed time: {time_end - time_start} [sec] for last {steps - 1} steps") + + logger.info("finish training") + sd = control_net.state_dict() + + from safetensors.torch import save_file + + save_file(sd, r"E:\Work\SD\Tmp\sdxl\ctrl\control_net.safetensors") diff --git a/library/sdxl_original_unet.py b/library/sdxl_original_unet.py index 17c345a89..0aa07d0d6 100644 --- a/library/sdxl_original_unet.py +++ b/library/sdxl_original_unet.py @@ -30,7 +30,7 @@ from torch import nn from torch.nn import functional as F from einops import rearrange -from .utils import setup_logging +from library.utils import setup_logging setup_logging() import logging @@ -1156,9 +1156,9 @@ def set_deep_shrink(self, ds_depth_1, ds_timesteps_1=650, ds_depth_2=None, ds_ti self.ds_timesteps_2 = ds_timesteps_2 if ds_timesteps_2 is not None else 1000 self.ds_ratio = ds_ratio - def forward(self, x, timesteps=None, context=None, y=None, **kwargs): + def forward(self, x, timesteps=None, context=None, y=None, input_resi_add=None, mid_add=None, **kwargs): r""" - current implementation is a copy of `SdxlUNet2DConditionModel.forward()` with Deep Shrink. + current implementation is a copy of `SdxlUNet2DConditionModel.forward()` with Deep Shrink and ControlNet. """ _self = self.delegate @@ -1209,6 +1209,8 @@ def call_module(module, h, emb, context): hs.append(h) h = call_module(_self.middle_block, h, emb, context) + if mid_add is not None: + h = h + mid_add for module in _self.output_blocks: # Deep Shrink @@ -1217,7 +1219,11 @@ def call_module(module, h, emb, context): # print("upsample", h.shape, hs[-1].shape) h = resize_like(h, hs[-1]) - h = torch.cat([h, hs.pop()], dim=1) + resi = hs.pop() + if input_resi_add is not None: + resi = resi + input_resi_add.pop() + + h = torch.cat([h, resi], dim=1) h = call_module(module, h, emb, context) # Deep Shrink: in case of depth 0 diff --git a/library/sdxl_train_util.py b/library/sdxl_train_util.py index 5ac9eb3b2..e559e7185 100644 --- a/library/sdxl_train_util.py +++ b/library/sdxl_train_util.py @@ -12,7 +12,6 @@ from tqdm import tqdm from transformers import CLIPTokenizer from library import model_util, sdxl_model_util, train_util, sdxl_original_unet -from library.sdxl_lpw_stable_diffusion import SdxlStableDiffusionLongPromptWeightingPipeline from .utils import setup_logging setup_logging() @@ -365,9 +364,9 @@ def verify_sdxl_training_args(args: argparse.Namespace, support_text_encoder_cac # ) # logger.info(f"noise_offset is set to {args.noise_offset} / noise_offsetが{args.noise_offset}に設定されました") - assert ( - not hasattr(args, "weighted_captions") or not args.weighted_captions - ), "weighted_captions cannot be enabled in SDXL training currently / SDXL学習では今のところweighted_captionsを有効にすることはできません" + # assert ( + # not hasattr(args, "weighted_captions") or not args.weighted_captions + # ), "weighted_captions cannot be enabled in SDXL training currently / SDXL学習では今のところweighted_captionsを有効にすることはできません" if support_text_encoder_caching: if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs: @@ -379,4 +378,6 @@ def verify_sdxl_training_args(args: argparse.Namespace, support_text_encoder_cac def sample_images(*args, **kwargs): + from library.sdxl_lpw_stable_diffusion import SdxlStableDiffusionLongPromptWeightingPipeline + return train_util.sample_images_common(SdxlStableDiffusionLongPromptWeightingPipeline, *args, **kwargs) diff --git a/library/strategy_base.py b/library/strategy_base.py new file mode 100644 index 000000000..e88d273fc --- /dev/null +++ b/library/strategy_base.py @@ -0,0 +1,637 @@ +# base class for platform strategies. this file defines the interface for strategies + +import os +import re +from typing import Any, List, Optional, Tuple, Union, Callable + +import numpy as np +import torch +from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextModelWithProjection + + +# TODO remove circular import by moving ImageInfo to a separate file +# from library.train_util import ImageInfo + +from library.utils import setup_logging + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + + +class TokenizeStrategy: + _strategy = None # strategy instance: actual strategy class + + _re_attention = re.compile( + r"""\\\(| +\\\)| +\\\[| +\\]| +\\\\| +\\| +\(| +\[| +:([+-]?[.\d]+)\)| +\)| +]| +[^\\()\[\]:]+| +: +""", + re.X, + ) + + @classmethod + def set_strategy(cls, strategy): + if cls._strategy is not None: + raise RuntimeError(f"Internal error. {cls.__name__} strategy is already set") + cls._strategy = strategy + + @classmethod + def get_strategy(cls) -> Optional["TokenizeStrategy"]: + return cls._strategy + + def _load_tokenizer( + self, model_class: Any, model_id: str, subfolder: Optional[str] = None, tokenizer_cache_dir: Optional[str] = None + ) -> Any: + tokenizer = None + if tokenizer_cache_dir: + local_tokenizer_path = os.path.join(tokenizer_cache_dir, model_id.replace("/", "_")) + if os.path.exists(local_tokenizer_path): + logger.info(f"load tokenizer from cache: {local_tokenizer_path}") + tokenizer = model_class.from_pretrained(local_tokenizer_path) # same for v1 and v2 + + if tokenizer is None: + tokenizer = model_class.from_pretrained(model_id, subfolder=subfolder) + + if tokenizer_cache_dir and not os.path.exists(local_tokenizer_path): + logger.info(f"save Tokenizer to cache: {local_tokenizer_path}") + tokenizer.save_pretrained(local_tokenizer_path) + + return tokenizer + + def tokenize(self, text: Union[str, List[str]]) -> List[torch.Tensor]: + raise NotImplementedError + + def tokenize_with_weights(self, text: Union[str, List[str]]) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: + """ + returns: [tokens1, tokens2, ...], [weights1, weights2, ...] + """ + raise NotImplementedError + + def _get_weighted_input_ids( + self, tokenizer: CLIPTokenizer, text: str, max_length: Optional[int] = None + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + max_length includes starting and ending tokens. + """ + + def parse_prompt_attention(text): + """ + Parses a string with attention tokens and returns a list of pairs: text and its associated weight. + Accepted tokens are: + (abc) - increases attention to abc by a multiplier of 1.1 + (abc:3.12) - increases attention to abc by a multiplier of 3.12 + [abc] - decreases attention to abc by a multiplier of 1.1 + \( - literal character '(' + \[ - literal character '[' + \) - literal character ')' + \] - literal character ']' + \\ - literal character '\' + anything else - just text + >>> parse_prompt_attention('normal text') + [['normal text', 1.0]] + >>> parse_prompt_attention('an (important) word') + [['an ', 1.0], ['important', 1.1], [' word', 1.0]] + >>> parse_prompt_attention('(unbalanced') + [['unbalanced', 1.1]] + >>> parse_prompt_attention('\(literal\]') + [['(literal]', 1.0]] + >>> parse_prompt_attention('(unnecessary)(parens)') + [['unnecessaryparens', 1.1]] + >>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).') + [['a ', 1.0], + ['house', 1.5730000000000004], + [' ', 1.1], + ['on', 1.0], + [' a ', 1.1], + ['hill', 0.55], + [', sun, ', 1.1], + ['sky', 1.4641000000000006], + ['.', 1.1]] + """ + + res = [] + round_brackets = [] + square_brackets = [] + + round_bracket_multiplier = 1.1 + square_bracket_multiplier = 1 / 1.1 + + def multiply_range(start_position, multiplier): + for p in range(start_position, len(res)): + res[p][1] *= multiplier + + for m in TokenizeStrategy._re_attention.finditer(text): + text = m.group(0) + weight = m.group(1) + + if text.startswith("\\"): + res.append([text[1:], 1.0]) + elif text == "(": + round_brackets.append(len(res)) + elif text == "[": + square_brackets.append(len(res)) + elif weight is not None and len(round_brackets) > 0: + multiply_range(round_brackets.pop(), float(weight)) + elif text == ")" and len(round_brackets) > 0: + multiply_range(round_brackets.pop(), round_bracket_multiplier) + elif text == "]" and len(square_brackets) > 0: + multiply_range(square_brackets.pop(), square_bracket_multiplier) + else: + res.append([text, 1.0]) + + for pos in round_brackets: + multiply_range(pos, round_bracket_multiplier) + + for pos in square_brackets: + multiply_range(pos, square_bracket_multiplier) + + if len(res) == 0: + res = [["", 1.0]] + + # merge runs of identical weights + i = 0 + while i + 1 < len(res): + if res[i][1] == res[i + 1][1]: + res[i][0] += res[i + 1][0] + res.pop(i + 1) + else: + i += 1 + + return res + + def get_prompts_with_weights(text: str, max_length: int): + r""" + Tokenize a list of prompts and return its tokens with weights of each token. max_length does not include starting and ending token. + + No padding, starting or ending token is included. + """ + truncated = False + + texts_and_weights = parse_prompt_attention(text) + tokens = [] + weights = [] + for word, weight in texts_and_weights: + # tokenize and discard the starting and the ending token + token = tokenizer(word).input_ids[1:-1] + tokens += token + # copy the weight by length of token + weights += [weight] * len(token) + # stop if the text is too long (longer than truncation limit) + if len(tokens) > max_length: + truncated = True + break + # truncate + if len(tokens) > max_length: + truncated = True + tokens = tokens[:max_length] + weights = weights[:max_length] + if truncated: + logger.warning("Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples") + return tokens, weights + + def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, pad): + r""" + Pad the tokens (with starting and ending tokens) and weights (with 1.0) to max_length. + """ + tokens = [bos] + tokens + [eos] + [pad] * (max_length - 2 - len(tokens)) + weights = [1.0] + weights + [1.0] * (max_length - 1 - len(weights)) + return tokens, weights + + if max_length is None: + max_length = tokenizer.model_max_length + + tokens, weights = get_prompts_with_weights(text, max_length - 2) + tokens, weights = pad_tokens_and_weights( + tokens, weights, max_length, tokenizer.bos_token_id, tokenizer.eos_token_id, tokenizer.pad_token_id + ) + return torch.tensor(tokens).unsqueeze(0), torch.tensor(weights).unsqueeze(0) + + def _get_input_ids( + self, tokenizer: CLIPTokenizer, text: str, max_length: Optional[int] = None, weighted: bool = False + ) -> torch.Tensor: + """ + for SD1.5/2.0/SDXL + TODO support batch input + """ + if max_length is None: + max_length = tokenizer.model_max_length - 2 + + if weighted: + input_ids, weights = self._get_weighted_input_ids(tokenizer, text, max_length) + else: + input_ids = tokenizer(text, padding="max_length", truncation=True, max_length=max_length, return_tensors="pt").input_ids + + if max_length > tokenizer.model_max_length: + input_ids = input_ids.squeeze(0) + iids_list = [] + if tokenizer.pad_token_id == tokenizer.eos_token_id: + # v1 + # 77以上の時は " .... " でトータル227とかになっているので、"..."の三連に変換する + # 1111氏のやつは , で区切る、とかしているようだが とりあえず単純に + for i in range(1, max_length - tokenizer.model_max_length + 2, tokenizer.model_max_length - 2): # (1, 152, 75) + ids_chunk = ( + input_ids[0].unsqueeze(0), + input_ids[i : i + tokenizer.model_max_length - 2], + input_ids[-1].unsqueeze(0), + ) + ids_chunk = torch.cat(ids_chunk) + iids_list.append(ids_chunk) + else: + # v2 or SDXL + # 77以上の時は " .... ..." でトータル227とかになっているので、"... ..."の三連に変換する + for i in range(1, max_length - tokenizer.model_max_length + 2, tokenizer.model_max_length - 2): + ids_chunk = ( + input_ids[0].unsqueeze(0), # BOS + input_ids[i : i + tokenizer.model_max_length - 2], + input_ids[-1].unsqueeze(0), + ) # PAD or EOS + ids_chunk = torch.cat(ids_chunk) + + # 末尾が または の場合は、何もしなくてよい + # 末尾が x の場合は末尾を に変える(x なら結果的に変化なし) + if ids_chunk[-2] != tokenizer.eos_token_id and ids_chunk[-2] != tokenizer.pad_token_id: + ids_chunk[-1] = tokenizer.eos_token_id + # 先頭が ... の場合は ... に変える + if ids_chunk[1] == tokenizer.pad_token_id: + ids_chunk[1] = tokenizer.eos_token_id + + iids_list.append(ids_chunk) + + input_ids = torch.stack(iids_list) # 3,77 + + if weighted: + weights = weights.squeeze(0) + new_weights = torch.ones(input_ids.shape) + for i in range(1, max_length - tokenizer.model_max_length + 2, tokenizer.model_max_length - 2): + b = i // (tokenizer.model_max_length - 2) + new_weights[b, 1 : 1 + tokenizer.model_max_length - 2] = weights[i : i + tokenizer.model_max_length - 2] + weights = new_weights + + if weighted: + return input_ids, weights + return input_ids + + +class TextEncodingStrategy: + _strategy = None # strategy instance: actual strategy class + + @classmethod + def set_strategy(cls, strategy): + if cls._strategy is not None: + raise RuntimeError(f"Internal error. {cls.__name__} strategy is already set") + cls._strategy = strategy + + @classmethod + def get_strategy(cls) -> Optional["TextEncodingStrategy"]: + return cls._strategy + + def encode_tokens( + self, tokenize_strategy: TokenizeStrategy, models: List[Any], tokens: List[torch.Tensor] + ) -> List[torch.Tensor]: + """ + Encode tokens into embeddings and outputs. + :param tokens: list of token tensors for each TextModel + :return: list of output embeddings for each architecture + """ + raise NotImplementedError + + def encode_tokens_with_weights( + self, tokenize_strategy: TokenizeStrategy, models: List[Any], tokens: List[torch.Tensor], weights: List[torch.Tensor] + ) -> List[torch.Tensor]: + """ + Encode tokens into embeddings and outputs. + :param tokens: list of token tensors for each TextModel + :param weights: list of weight tensors for each TextModel + :return: list of output embeddings for each architecture + """ + raise NotImplementedError + + +class TextEncoderOutputsCachingStrategy: + _strategy = None # strategy instance: actual strategy class + + def __init__( + self, + cache_to_disk: bool, + batch_size: Optional[int], + skip_disk_cache_validity_check: bool, + is_partial: bool = False, + is_weighted: bool = False, + ) -> None: + self._cache_to_disk = cache_to_disk + self._batch_size = batch_size + self.skip_disk_cache_validity_check = skip_disk_cache_validity_check + self._is_partial = is_partial + self._is_weighted = is_weighted + + @classmethod + def set_strategy(cls, strategy): + if cls._strategy is not None: + raise RuntimeError(f"Internal error. {cls.__name__} strategy is already set") + cls._strategy = strategy + + @classmethod + def get_strategy(cls) -> Optional["TextEncoderOutputsCachingStrategy"]: + return cls._strategy + + @property + def cache_to_disk(self): + return self._cache_to_disk + + @property + def batch_size(self): + return self._batch_size + + @property + def is_partial(self): + return self._is_partial + + @property + def is_weighted(self): + return self._is_weighted + + def get_outputs_npz_path(self, image_abs_path: str) -> str: + raise NotImplementedError + + def load_outputs_npz(self, npz_path: str) -> List[np.ndarray]: + raise NotImplementedError + + def is_disk_cached_outputs_expected(self, npz_path: str) -> bool: + raise NotImplementedError + + def cache_batch_outputs( + self, tokenize_strategy: TokenizeStrategy, models: List[Any], text_encoding_strategy: TextEncodingStrategy, batch: List + ): + raise NotImplementedError + + +class LatentsCachingStrategy: + # TODO commonize utillity functions to this class, such as npz handling etc. + + _strategy = None # strategy instance: actual strategy class + + def __init__(self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None: + self._cache_to_disk = cache_to_disk + self._batch_size = batch_size + self.skip_disk_cache_validity_check = skip_disk_cache_validity_check + + @classmethod + def set_strategy(cls, strategy): + if cls._strategy is not None: + raise RuntimeError(f"Internal error. {cls.__name__} strategy is already set") + cls._strategy = strategy + + @classmethod + def get_strategy(cls) -> Optional["LatentsCachingStrategy"]: + return cls._strategy + + @property + def cache_to_disk(self): + return self._cache_to_disk + + @property + def batch_size(self): + return self._batch_size + + @property + def cache_suffix(self): + raise NotImplementedError + + def get_image_size_from_disk_cache_path(self, absolute_path: str, npz_path: str) -> Tuple[Optional[int], Optional[int]]: + w, h = os.path.splitext(npz_path)[0].split("_")[-2].split("x") + return int(w), int(h) + + def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) -> str: + raise NotImplementedError + + def is_disk_cached_latents_expected( + self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool + ) -> bool: + raise NotImplementedError + + def cache_batch_latents(self, model: Any, batch: List, flip_aug: bool, alpha_mask: bool, random_crop: bool): + raise NotImplementedError + + def _default_is_disk_cached_latents_expected( + self, + latents_stride: int, + bucket_reso: Tuple[int, int], + npz_path: str, + flip_aug: bool, + apply_alpha_mask: bool, + multi_resolution: bool = False, + ) -> bool: + """ + Args: + latents_stride: stride of latents + bucket_reso: resolution of the bucket + npz_path: path to the npz file + flip_aug: whether to flip images + apply_alpha_mask: whether to apply alpha mask + multi_resolution: whether to use multi-resolution latents + + Returns: + bool + """ + if not self.cache_to_disk: + return False + if not os.path.exists(npz_path): + return False + if self.skip_disk_cache_validity_check: + return True + + expected_latents_size = (bucket_reso[1] // latents_stride, bucket_reso[0] // latents_stride) # bucket_reso is (W, H) + + # e.g. "_32x64", HxW + key_reso_suffix = f"_{expected_latents_size[0]}x{expected_latents_size[1]}" if multi_resolution else "" + + try: + npz = np.load(npz_path) + if "latents" + key_reso_suffix not in npz: + return False + if flip_aug and "latents_flipped" + key_reso_suffix not in npz: + return False + if apply_alpha_mask and "alpha_mask" + key_reso_suffix not in npz: + return False + except Exception as e: + logger.error(f"Error loading file: {npz_path}") + raise e + + return True + + # TODO remove circular dependency for ImageInfo + def _default_cache_batch_latents( + self, + encode_by_vae: Callable, + vae_device: torch.device, + vae_dtype: torch.dtype, + image_infos: List, + flip_aug: bool, + apply_alpha_mask: bool, + random_crop: bool, + multi_resolution: bool = False, + ): + """ + Default implementation for cache_batch_latents. Image loading, VAE, flipping, alpha mask handling are common. + + Args: + encode_by_vae: function to encode images by VAE + vae_device: device to use for VAE + vae_dtype: dtype to use for VAE + image_infos: list of ImageInfo + flip_aug: whether to flip images + apply_alpha_mask: whether to apply alpha mask + random_crop: whether to random crop images + multi_resolution: whether to use multi-resolution latents + + Returns: + None + """ + from library import train_util # import here to avoid circular import + + img_tensor, alpha_masks, original_sizes, crop_ltrbs = train_util.load_images_and_masks_for_caching( + image_infos, apply_alpha_mask, random_crop + ) + img_tensor = img_tensor.to(device=vae_device, dtype=vae_dtype) + + with torch.no_grad(): + latents_tensors = encode_by_vae(img_tensor).to("cpu") + if flip_aug: + img_tensor = torch.flip(img_tensor, dims=[3]) + with torch.no_grad(): + flipped_latents = encode_by_vae(img_tensor).to("cpu") + else: + flipped_latents = [None] * len(latents_tensors) + + # for info, latents, flipped_latent, alpha_mask in zip(image_infos, latents_tensors, flipped_latents, alpha_masks): + for i in range(len(image_infos)): + info = image_infos[i] + latents = latents_tensors[i] + flipped_latent = flipped_latents[i] + alpha_mask = alpha_masks[i] + original_size = original_sizes[i] + crop_ltrb = crop_ltrbs[i] + + latents_size = latents.shape[1:3] # H, W + key_reso_suffix = f"_{latents_size[0]}x{latents_size[1]}" if multi_resolution else "" # e.g. "_32x64", HxW + + if self.cache_to_disk: + self.save_latents_to_disk( + info.latents_npz, latents, original_size, crop_ltrb, flipped_latent, alpha_mask, key_reso_suffix + ) + else: + info.latents_original_size = original_size + info.latents_crop_ltrb = crop_ltrb + info.latents = latents + if flip_aug: + info.latents_flipped = flipped_latent + info.alpha_mask = alpha_mask + + def load_latents_from_disk( + self, npz_path: str, bucket_reso: Tuple[int, int] + ) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]: + """ + for SD/SDXL + + Args: + npz_path (str): Path to the npz file. + bucket_reso (Tuple[int, int]): The resolution of the bucket. + + Returns: + Tuple[ + Optional[np.ndarray], + Optional[List[int]], + Optional[List[int]], + Optional[np.ndarray], + Optional[np.ndarray] + ]: Latent np tensors, original size, crop (left top, right bottom), flipped latents, alpha mask + """ + return self._default_load_latents_from_disk(None, npz_path, bucket_reso) + + def _default_load_latents_from_disk( + self, latents_stride: Optional[int], npz_path: str, bucket_reso: Tuple[int, int] + ) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]: + """ + Args: + latents_stride (Optional[int]): Stride for latents. If None, load all latents. + npz_path (str): Path to the npz file. + bucket_reso (Tuple[int, int]): The resolution of the bucket. + + Returns: + Tuple[ + Optional[np.ndarray], + Optional[List[int]], + Optional[List[int]], + Optional[np.ndarray], + Optional[np.ndarray] + ]: Latent np tensors, original size, crop (left top, right bottom), flipped latents, alpha mask + """ + if latents_stride is None: + key_reso_suffix = "" + else: + latents_size = (bucket_reso[1] // latents_stride, bucket_reso[0] // latents_stride) # bucket_reso is (W, H) + key_reso_suffix = f"_{latents_size[0]}x{latents_size[1]}" # e.g. "_32x64", HxW + + npz = np.load(npz_path) + if "latents" + key_reso_suffix not in npz: + raise ValueError(f"latents{key_reso_suffix} not found in {npz_path}") + + latents = npz["latents" + key_reso_suffix] + original_size = npz["original_size" + key_reso_suffix].tolist() + crop_ltrb = npz["crop_ltrb" + key_reso_suffix].tolist() + flipped_latents = npz["latents_flipped" + key_reso_suffix] if "latents_flipped" + key_reso_suffix in npz else None + alpha_mask = npz["alpha_mask" + key_reso_suffix] if "alpha_mask" + key_reso_suffix in npz else None + return latents, original_size, crop_ltrb, flipped_latents, alpha_mask + + def save_latents_to_disk( + self, + npz_path, + latents_tensor, + original_size, + crop_ltrb, + flipped_latents_tensor=None, + alpha_mask=None, + key_reso_suffix="", + ): + """ + Args: + npz_path (str): Path to the npz file. + latents_tensor (torch.Tensor): Latent tensor + original_size (List[int]): Original size of the image + crop_ltrb (List[int]): Crop left top right bottom + flipped_latents_tensor (Optional[torch.Tensor]): Flipped latent tensor + alpha_mask (Optional[torch.Tensor]): Alpha mask + key_reso_suffix (str): Key resolution suffix + + Returns: + None + """ + kwargs = {} + + if os.path.exists(npz_path): + # load existing npz and update it + npz = np.load(npz_path) + for key in npz.files: + kwargs[key] = npz[key] + + # TODO float() is needed if vae is in bfloat16. Remove it if vae is float16. + kwargs["latents" + key_reso_suffix] = latents_tensor.float().cpu().numpy() + kwargs["original_size" + key_reso_suffix] = np.array(original_size) + kwargs["crop_ltrb" + key_reso_suffix] = np.array(crop_ltrb) + if flipped_latents_tensor is not None: + kwargs["latents_flipped" + key_reso_suffix] = flipped_latents_tensor.float().cpu().numpy() + if alpha_mask is not None: + kwargs["alpha_mask" + key_reso_suffix] = alpha_mask.float().cpu().numpy() + np.savez(npz_path, **kwargs) diff --git a/library/strategy_flux.py b/library/strategy_flux.py new file mode 100644 index 000000000..5e65927f8 --- /dev/null +++ b/library/strategy_flux.py @@ -0,0 +1,271 @@ +import os +import glob +from typing import Any, List, Optional, Tuple, Union +import torch +import numpy as np +from transformers import CLIPTokenizer, T5TokenizerFast + +from library import flux_utils, train_util +from library.strategy_base import LatentsCachingStrategy, TextEncodingStrategy, TokenizeStrategy, TextEncoderOutputsCachingStrategy + +from library.utils import setup_logging + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + + +CLIP_L_TOKENIZER_ID = "openai/clip-vit-large-patch14" +T5_XXL_TOKENIZER_ID = "google/t5-v1_1-xxl" + + +class FluxTokenizeStrategy(TokenizeStrategy): + def __init__(self, t5xxl_max_length: int = 512, tokenizer_cache_dir: Optional[str] = None) -> None: + self.t5xxl_max_length = t5xxl_max_length + self.clip_l = self._load_tokenizer(CLIPTokenizer, CLIP_L_TOKENIZER_ID, tokenizer_cache_dir=tokenizer_cache_dir) + self.t5xxl = self._load_tokenizer(T5TokenizerFast, T5_XXL_TOKENIZER_ID, tokenizer_cache_dir=tokenizer_cache_dir) + + def tokenize(self, text: Union[str, List[str]]) -> List[torch.Tensor]: + text = [text] if isinstance(text, str) else text + + l_tokens = self.clip_l(text, max_length=77, padding="max_length", truncation=True, return_tensors="pt") + t5_tokens = self.t5xxl(text, max_length=self.t5xxl_max_length, padding="max_length", truncation=True, return_tensors="pt") + + t5_attn_mask = t5_tokens["attention_mask"] + l_tokens = l_tokens["input_ids"] + t5_tokens = t5_tokens["input_ids"] + + return [l_tokens, t5_tokens, t5_attn_mask] + + +class FluxTextEncodingStrategy(TextEncodingStrategy): + def __init__(self, apply_t5_attn_mask: Optional[bool] = None) -> None: + """ + Args: + apply_t5_attn_mask: Default value for apply_t5_attn_mask. + """ + self.apply_t5_attn_mask = apply_t5_attn_mask + + def encode_tokens( + self, + tokenize_strategy: TokenizeStrategy, + models: List[Any], + tokens: List[torch.Tensor], + apply_t5_attn_mask: Optional[bool] = None, + ) -> List[torch.Tensor]: + # supports single model inference + + if apply_t5_attn_mask is None: + apply_t5_attn_mask = self.apply_t5_attn_mask + + clip_l, t5xxl = models if len(models) == 2 else (models[0], None) + l_tokens, t5_tokens = tokens[:2] + t5_attn_mask = tokens[2] if len(tokens) > 2 else None + + # clip_l is None when using T5 only + if clip_l is not None and l_tokens is not None: + l_pooled = clip_l(l_tokens.to(clip_l.device))["pooler_output"] + else: + l_pooled = None + + # t5xxl is None when using CLIP only + if t5xxl is not None and t5_tokens is not None: + # t5_out is [b, max length, 4096] + attention_mask = None if not apply_t5_attn_mask else t5_attn_mask.to(t5xxl.device) + t5_out, _ = t5xxl(t5_tokens.to(t5xxl.device), attention_mask, return_dict=False, output_hidden_states=True) + # if zero_pad_t5_output: + # t5_out = t5_out * t5_attn_mask.to(t5_out.device).unsqueeze(-1) + txt_ids = torch.zeros(t5_out.shape[0], t5_out.shape[1], 3, device=t5_out.device) + else: + t5_out = None + txt_ids = None + t5_attn_mask = None # caption may be dropped/shuffled, so t5_attn_mask should not be used to make sure the mask is same as the cached one + + return [l_pooled, t5_out, txt_ids, t5_attn_mask] # returns t5_attn_mask for attention mask in transformer + + +class FluxTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy): + FLUX_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX = "_flux_te.npz" + + def __init__( + self, + cache_to_disk: bool, + batch_size: int, + skip_disk_cache_validity_check: bool, + is_partial: bool = False, + apply_t5_attn_mask: bool = False, + ) -> None: + super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check, is_partial) + self.apply_t5_attn_mask = apply_t5_attn_mask + + self.warn_fp8_weights = False + + def get_outputs_npz_path(self, image_abs_path: str) -> str: + return os.path.splitext(image_abs_path)[0] + FluxTextEncoderOutputsCachingStrategy.FLUX_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX + + def is_disk_cached_outputs_expected(self, npz_path: str): + if not self.cache_to_disk: + return False + if not os.path.exists(npz_path): + return False + if self.skip_disk_cache_validity_check: + return True + + try: + npz = np.load(npz_path) + if "l_pooled" not in npz: + return False + if "t5_out" not in npz: + return False + if "txt_ids" not in npz: + return False + if "t5_attn_mask" not in npz: + return False + if "apply_t5_attn_mask" not in npz: + return False + npz_apply_t5_attn_mask = npz["apply_t5_attn_mask"] + if npz_apply_t5_attn_mask != self.apply_t5_attn_mask: + return False + except Exception as e: + logger.error(f"Error loading file: {npz_path}") + raise e + + return True + + def load_outputs_npz(self, npz_path: str) -> List[np.ndarray]: + data = np.load(npz_path) + l_pooled = data["l_pooled"] + t5_out = data["t5_out"] + txt_ids = data["txt_ids"] + t5_attn_mask = data["t5_attn_mask"] + # apply_t5_attn_mask should be same as self.apply_t5_attn_mask + return [l_pooled, t5_out, txt_ids, t5_attn_mask] + + def cache_batch_outputs( + self, tokenize_strategy: TokenizeStrategy, models: List[Any], text_encoding_strategy: TextEncodingStrategy, infos: List + ): + if not self.warn_fp8_weights: + if flux_utils.get_t5xxl_actual_dtype(models[1]) == torch.float8_e4m3fn: + logger.warning( + "T5 model is using fp8 weights for caching. This may affect the quality of the cached outputs." + " / T5モデルはfp8の重みを使用しています。これはキャッシュの品質に影響を与える可能性があります。" + ) + self.warn_fp8_weights = True + + flux_text_encoding_strategy: FluxTextEncodingStrategy = text_encoding_strategy + captions = [info.caption for info in infos] + + tokens_and_masks = tokenize_strategy.tokenize(captions) + with torch.no_grad(): + # attn_mask is applied in text_encoding_strategy.encode_tokens if apply_t5_attn_mask is True + l_pooled, t5_out, txt_ids, _ = flux_text_encoding_strategy.encode_tokens(tokenize_strategy, models, tokens_and_masks) + + if l_pooled.dtype == torch.bfloat16: + l_pooled = l_pooled.float() + if t5_out.dtype == torch.bfloat16: + t5_out = t5_out.float() + if txt_ids.dtype == torch.bfloat16: + txt_ids = txt_ids.float() + + l_pooled = l_pooled.cpu().numpy() + t5_out = t5_out.cpu().numpy() + txt_ids = txt_ids.cpu().numpy() + t5_attn_mask = tokens_and_masks[2].cpu().numpy() + + for i, info in enumerate(infos): + l_pooled_i = l_pooled[i] + t5_out_i = t5_out[i] + txt_ids_i = txt_ids[i] + t5_attn_mask_i = t5_attn_mask[i] + apply_t5_attn_mask_i = self.apply_t5_attn_mask + + if self.cache_to_disk: + np.savez( + info.text_encoder_outputs_npz, + l_pooled=l_pooled_i, + t5_out=t5_out_i, + txt_ids=txt_ids_i, + t5_attn_mask=t5_attn_mask_i, + apply_t5_attn_mask=apply_t5_attn_mask_i, + ) + else: + # it's fine that attn mask is not None. it's overwritten before calling the model if necessary + info.text_encoder_outputs = (l_pooled_i, t5_out_i, txt_ids_i, t5_attn_mask_i) + + +class FluxLatentsCachingStrategy(LatentsCachingStrategy): + FLUX_LATENTS_NPZ_SUFFIX = "_flux.npz" + + def __init__(self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None: + super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check) + + @property + def cache_suffix(self) -> str: + return FluxLatentsCachingStrategy.FLUX_LATENTS_NPZ_SUFFIX + + def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) -> str: + return ( + os.path.splitext(absolute_path)[0] + + f"_{image_size[0]:04d}x{image_size[1]:04d}" + + FluxLatentsCachingStrategy.FLUX_LATENTS_NPZ_SUFFIX + ) + + def is_disk_cached_latents_expected(self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool): + return self._default_is_disk_cached_latents_expected(8, bucket_reso, npz_path, flip_aug, alpha_mask, multi_resolution=True) + + def load_latents_from_disk( + self, npz_path: str, bucket_reso: Tuple[int, int] + ) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]: + return self._default_load_latents_from_disk(8, npz_path, bucket_reso) # support multi-resolution + + # TODO remove circular dependency for ImageInfo + def cache_batch_latents(self, vae, image_infos: List, flip_aug: bool, alpha_mask: bool, random_crop: bool): + encode_by_vae = lambda img_tensor: vae.encode(img_tensor).to("cpu") + vae_device = vae.device + vae_dtype = vae.dtype + + self._default_cache_batch_latents( + encode_by_vae, vae_device, vae_dtype, image_infos, flip_aug, alpha_mask, random_crop, multi_resolution=True + ) + + if not train_util.HIGH_VRAM: + train_util.clean_memory_on_device(vae.device) + + +if __name__ == "__main__": + # test code for FluxTokenizeStrategy + # tokenizer = sd3_models.SD3Tokenizer() + strategy = FluxTokenizeStrategy(256) + text = "hello world" + + l_tokens, g_tokens, t5_tokens = strategy.tokenize(text) + # print(l_tokens.shape) + print(l_tokens) + print(g_tokens) + print(t5_tokens) + + texts = ["hello world", "the quick brown fox jumps over the lazy dog"] + l_tokens_2 = strategy.clip_l(texts, max_length=77, padding="max_length", truncation=True, return_tensors="pt") + g_tokens_2 = strategy.clip_g(texts, max_length=77, padding="max_length", truncation=True, return_tensors="pt") + t5_tokens_2 = strategy.t5xxl( + texts, max_length=strategy.t5xxl_max_length, padding="max_length", truncation=True, return_tensors="pt" + ) + print(l_tokens_2) + print(g_tokens_2) + print(t5_tokens_2) + + # compare + print(torch.allclose(l_tokens, l_tokens_2["input_ids"][0])) + print(torch.allclose(g_tokens, g_tokens_2["input_ids"][0])) + print(torch.allclose(t5_tokens, t5_tokens_2["input_ids"][0])) + + text = ",".join(["hello world! this is long text"] * 50) + l_tokens, g_tokens, t5_tokens = strategy.tokenize(text) + print(l_tokens) + print(g_tokens) + print(t5_tokens) + + print(f"model max length l: {strategy.clip_l.model_max_length}") + print(f"model max length g: {strategy.clip_g.model_max_length}") + print(f"model max length t5: {strategy.t5xxl.model_max_length}") diff --git a/library/strategy_hunyuan_image.py b/library/strategy_hunyuan_image.py new file mode 100644 index 000000000..5c704728f --- /dev/null +++ b/library/strategy_hunyuan_image.py @@ -0,0 +1,218 @@ +import os +from typing import Any, List, Optional, Tuple, Union +import torch +import numpy as np +from transformers import AutoTokenizer, Qwen2Tokenizer + +from library import hunyuan_image_text_encoder, hunyuan_image_vae, train_util +from library.strategy_base import LatentsCachingStrategy, TextEncodingStrategy, TokenizeStrategy, TextEncoderOutputsCachingStrategy + +from library.utils import setup_logging + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + + +class HunyuanImageTokenizeStrategy(TokenizeStrategy): + def __init__(self, tokenizer_cache_dir: Optional[str] = None) -> None: + self.vlm_tokenizer = self._load_tokenizer( + Qwen2Tokenizer, hunyuan_image_text_encoder.QWEN_2_5_VL_IMAGE_ID, tokenizer_cache_dir=tokenizer_cache_dir + ) + self.byt5_tokenizer = self._load_tokenizer( + AutoTokenizer, hunyuan_image_text_encoder.BYT5_TOKENIZER_PATH, subfolder="", tokenizer_cache_dir=tokenizer_cache_dir + ) + + def tokenize(self, text: Union[str, List[str]]) -> List[torch.Tensor]: + text = [text] if isinstance(text, str) else text + + vlm_tokens, vlm_mask = hunyuan_image_text_encoder.get_qwen_tokens(self.vlm_tokenizer, text) + + # byt5_tokens, byt5_mask = hunyuan_image_text_encoder.get_byt5_text_tokens(self.byt5_tokenizer, text) + byt5_tokens = [] + byt5_mask = [] + for t in text: + tokens, mask = hunyuan_image_text_encoder.get_byt5_text_tokens(self.byt5_tokenizer, t) + if tokens is None: + tokens = torch.zeros((1, 1), dtype=torch.long) + mask = torch.zeros((1, 1), dtype=torch.long) + byt5_tokens.append(tokens) + byt5_mask.append(mask) + max_len = max([m.shape[1] for m in byt5_mask]) + byt5_tokens = torch.cat([torch.nn.functional.pad(t, (0, max_len - t.shape[1]), value=0) for t in byt5_tokens], dim=0) + byt5_mask = torch.cat([torch.nn.functional.pad(m, (0, max_len - m.shape[1]), value=0) for m in byt5_mask], dim=0) + + return [vlm_tokens, vlm_mask, byt5_tokens, byt5_mask] + + +class HunyuanImageTextEncodingStrategy(TextEncodingStrategy): + def __init__(self) -> None: + pass + + def encode_tokens( + self, tokenize_strategy: TokenizeStrategy, models: List[Any], tokens: List[torch.Tensor] + ) -> List[torch.Tensor]: + vlm_tokens, vlm_mask, byt5_tokens, byt5_mask = tokens + + qwen2vlm, byt5 = models + + # autocast and no_grad are handled in hunyuan_image_text_encoder + vlm_embed, vlm_mask = hunyuan_image_text_encoder.get_qwen_prompt_embeds_from_tokens(qwen2vlm, vlm_tokens, vlm_mask) + + # ocr_mask, byt5_embed, byt5_mask = hunyuan_image_text_encoder.get_byt5_prompt_embeds_from_tokens( + # byt5, byt5_tokens, byt5_mask + # ) + ocr_mask, byt5_embed, byt5_updated_mask = [], [], [] + for i in range(byt5_tokens.shape[0]): + ocr_m, byt5_e, byt5_m = hunyuan_image_text_encoder.get_byt5_prompt_embeds_from_tokens( + byt5, byt5_tokens[i : i + 1], byt5_mask[i : i + 1] + ) + ocr_mask.append(torch.zeros((1,), dtype=torch.long) + (1 if ocr_m[0] else 0)) # 1 or 0 + byt5_embed.append(byt5_e) + byt5_updated_mask.append(byt5_m) + + ocr_mask = torch.cat(ocr_mask, dim=0).to(torch.bool) # [B] + byt5_embed = torch.cat(byt5_embed, dim=0) + byt5_updated_mask = torch.cat(byt5_updated_mask, dim=0) + + return [vlm_embed, vlm_mask, byt5_embed, byt5_updated_mask, ocr_mask] + + +class HunyuanImageTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy): + HUNYUAN_IMAGE_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX = "_hi_te.npz" + + def __init__( + self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool, is_partial: bool = False + ) -> None: + super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check, is_partial) + + def get_outputs_npz_path(self, image_abs_path: str) -> str: + return ( + os.path.splitext(image_abs_path)[0] + + HunyuanImageTextEncoderOutputsCachingStrategy.HUNYUAN_IMAGE_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX + ) + + def is_disk_cached_outputs_expected(self, npz_path: str): + if not self.cache_to_disk: + return False + if not os.path.exists(npz_path): + return False + if self.skip_disk_cache_validity_check: + return True + + try: + npz = np.load(npz_path) + if "vlm_embed" not in npz: + return False + if "vlm_mask" not in npz: + return False + if "byt5_embed" not in npz: + return False + if "byt5_mask" not in npz: + return False + if "ocr_mask" not in npz: + return False + except Exception as e: + logger.error(f"Error loading file: {npz_path}") + raise e + + return True + + def load_outputs_npz(self, npz_path: str) -> List[np.ndarray]: + data = np.load(npz_path) + vln_embed = data["vlm_embed"] + vlm_mask = data["vlm_mask"] + byt5_embed = data["byt5_embed"] + byt5_mask = data["byt5_mask"] + ocr_mask = data["ocr_mask"] + return [vln_embed, vlm_mask, byt5_embed, byt5_mask, ocr_mask] + + def cache_batch_outputs( + self, tokenize_strategy: TokenizeStrategy, models: List[Any], text_encoding_strategy: TextEncodingStrategy, infos: List + ): + huyuan_image_text_encoding_strategy: HunyuanImageTextEncodingStrategy = text_encoding_strategy + captions = [info.caption for info in infos] + + tokens_and_masks = tokenize_strategy.tokenize(captions) + with torch.no_grad(): + vlm_embed, vlm_mask, byt5_embed, byt5_mask, ocr_mask = huyuan_image_text_encoding_strategy.encode_tokens( + tokenize_strategy, models, tokens_and_masks + ) + + if vlm_embed.dtype == torch.bfloat16: + vlm_embed = vlm_embed.float() + if byt5_embed.dtype == torch.bfloat16: + byt5_embed = byt5_embed.float() + + vlm_embed = vlm_embed.cpu().numpy() + vlm_mask = vlm_mask.cpu().numpy() + byt5_embed = byt5_embed.cpu().numpy() + byt5_mask = byt5_mask.cpu().numpy() + ocr_mask = ocr_mask.cpu().numpy() + + for i, info in enumerate(infos): + vlm_embed_i = vlm_embed[i] + vlm_mask_i = vlm_mask[i] + byt5_embed_i = byt5_embed[i] + byt5_mask_i = byt5_mask[i] + ocr_mask_i = ocr_mask[i] + + if self.cache_to_disk: + np.savez( + info.text_encoder_outputs_npz, + vlm_embed=vlm_embed_i, + vlm_mask=vlm_mask_i, + byt5_embed=byt5_embed_i, + byt5_mask=byt5_mask_i, + ocr_mask=ocr_mask_i, + ) + else: + info.text_encoder_outputs = (vlm_embed_i, vlm_mask_i, byt5_embed_i, byt5_mask_i, ocr_mask_i) + + +class HunyuanImageLatentsCachingStrategy(LatentsCachingStrategy): + HUNYUAN_IMAGE_LATENTS_NPZ_SUFFIX = "_hi.npz" + + def __init__(self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None: + super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check) + + @property + def cache_suffix(self) -> str: + return HunyuanImageLatentsCachingStrategy.HUNYUAN_IMAGE_LATENTS_NPZ_SUFFIX + + def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) -> str: + return ( + os.path.splitext(absolute_path)[0] + + f"_{image_size[0]:04d}x{image_size[1]:04d}" + + HunyuanImageLatentsCachingStrategy.HUNYUAN_IMAGE_LATENTS_NPZ_SUFFIX + ) + + def is_disk_cached_latents_expected(self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool): + return self._default_is_disk_cached_latents_expected(32, bucket_reso, npz_path, flip_aug, alpha_mask, multi_resolution=True) + + def load_latents_from_disk( + self, npz_path: str, bucket_reso: Tuple[int, int] + ) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]: + return self._default_load_latents_from_disk(32, npz_path, bucket_reso) # support multi-resolution + + # TODO remove circular dependency for ImageInfo + def cache_batch_latents( + self, vae: hunyuan_image_vae.HunyuanVAE2D, image_infos: List, flip_aug: bool, alpha_mask: bool, random_crop: bool + ): + # encode_by_vae = lambda img_tensor: vae.encode(img_tensor).sample() + def encode_by_vae(img_tensor): + # no_grad is handled in _default_cache_batch_latents + nonlocal vae + with torch.autocast(device_type=vae.device.type, dtype=vae.dtype): + return vae.encode(img_tensor).sample() + + vae_device = vae.device + vae_dtype = vae.dtype + + self._default_cache_batch_latents( + encode_by_vae, vae_device, vae_dtype, image_infos, flip_aug, alpha_mask, random_crop, multi_resolution=True + ) + + if not train_util.HIGH_VRAM: + train_util.clean_memory_on_device(vae.device) diff --git a/library/strategy_lumina.py b/library/strategy_lumina.py new file mode 100644 index 000000000..964d9f7a4 --- /dev/null +++ b/library/strategy_lumina.py @@ -0,0 +1,375 @@ +import glob +import os +from typing import Any, List, Optional, Tuple, Union + +import torch +from transformers import AutoTokenizer, AutoModel, Gemma2Model, GemmaTokenizerFast +from library import train_util +from library.strategy_base import ( + LatentsCachingStrategy, + TokenizeStrategy, + TextEncodingStrategy, + TextEncoderOutputsCachingStrategy, +) +import numpy as np +from library.utils import setup_logging + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + + +GEMMA_ID = "google/gemma-2-2b" + + +class LuminaTokenizeStrategy(TokenizeStrategy): + def __init__( + self, system_prompt:str, max_length: Optional[int], tokenizer_cache_dir: Optional[str] = None + ) -> None: + self.tokenizer: GemmaTokenizerFast = AutoTokenizer.from_pretrained( + GEMMA_ID, cache_dir=tokenizer_cache_dir + ) + self.tokenizer.padding_side = "right" + + if system_prompt is None: + system_prompt = "" + system_prompt_special_token = "" + system_prompt = f"{system_prompt} {system_prompt_special_token} " if system_prompt else "" + self.system_prompt = system_prompt + + if max_length is None: + self.max_length = 256 + else: + self.max_length = max_length + + def tokenize( + self, text: Union[str, List[str]], is_negative: bool = False + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Args: + text (Union[str, List[str]]): Text to tokenize + + Returns: + Tuple[torch.Tensor, torch.Tensor]: + token input ids, attention_masks + """ + text = [text] if isinstance(text, str) else text + + # In training, we always add system prompt (is_negative=False) + if not is_negative: + # Add system prompt to the beginning of each text + text = [self.system_prompt + t for t in text] + + encodings = self.tokenizer( + text, + max_length=self.max_length, + return_tensors="pt", + padding="max_length", + truncation=True, + pad_to_multiple_of=8, + ) + return (encodings.input_ids, encodings.attention_mask) + + def tokenize_with_weights( + self, text: str | List[str] + ) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]]: + """ + Args: + text (Union[str, List[str]]): Text to tokenize + + Returns: + Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]]: + token input ids, attention_masks, weights + """ + # Gemma doesn't support weighted prompts, return uniform weights + tokens, attention_masks = self.tokenize(text) + weights = [torch.ones_like(t) for t in tokens] + return tokens, attention_masks, weights + + +class LuminaTextEncodingStrategy(TextEncodingStrategy): + def __init__(self) -> None: + super().__init__() + + def encode_tokens( + self, + tokenize_strategy: TokenizeStrategy, + models: List[Any], + tokens: Tuple[torch.Tensor, torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Args: + tokenize_strategy (LuminaTokenizeStrategy): Tokenize strategy + models (List[Any]): Text encoders + tokens (Tuple[torch.Tensor, torch.Tensor]): tokens, attention_masks + + Returns: + Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + hidden_states, input_ids, attention_masks + """ + text_encoder = models[0] + # Check model or torch dynamo OptimizedModule + assert isinstance(text_encoder, Gemma2Model) or isinstance(text_encoder._orig_mod, Gemma2Model), f"text encoder is not Gemma2Model {text_encoder.__class__.__name__}" + input_ids, attention_masks = tokens + + outputs = text_encoder( + input_ids=input_ids.to(text_encoder.device), + attention_mask=attention_masks.to(text_encoder.device), + output_hidden_states=True, + return_dict=True, + ) + + return outputs.hidden_states[-2], input_ids, attention_masks + + def encode_tokens_with_weights( + self, + tokenize_strategy: TokenizeStrategy, + models: List[Any], + tokens: Tuple[torch.Tensor, torch.Tensor], + weights: List[torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Args: + tokenize_strategy (LuminaTokenizeStrategy): Tokenize strategy + models (List[Any]): Text encoders + tokens (Tuple[torch.Tensor, torch.Tensor]): tokens, attention_masks + weights_list (List[torch.Tensor]): Currently unused + + Returns: + Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + hidden_states, input_ids, attention_masks + """ + # For simplicity, use uniform weighting + return self.encode_tokens(tokenize_strategy, models, tokens) + + +class LuminaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy): + LUMINA_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX = "_lumina_te.npz" + + def __init__( + self, + cache_to_disk: bool, + batch_size: int, + skip_disk_cache_validity_check: bool, + is_partial: bool = False, + ) -> None: + super().__init__( + cache_to_disk, + batch_size, + skip_disk_cache_validity_check, + is_partial, + ) + + def get_outputs_npz_path(self, image_abs_path: str) -> str: + return ( + os.path.splitext(image_abs_path)[0] + + LuminaTextEncoderOutputsCachingStrategy.LUMINA_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX + ) + + def is_disk_cached_outputs_expected(self, npz_path: str) -> bool: + """ + Args: + npz_path (str): Path to the npz file. + + Returns: + bool: True if the npz file is expected to be cached. + """ + if not self.cache_to_disk: + return False + if not os.path.exists(npz_path): + return False + if self.skip_disk_cache_validity_check: + return True + + try: + npz = np.load(npz_path) + if "hidden_state" not in npz: + return False + if "attention_mask" not in npz: + return False + if "input_ids" not in npz: + return False + except Exception as e: + logger.error(f"Error loading file: {npz_path}") + raise e + + return True + + def load_outputs_npz(self, npz_path: str) -> List[np.ndarray]: + """ + Load outputs from a npz file + + Returns: + List[np.ndarray]: hidden_state, input_ids, attention_mask + """ + data = np.load(npz_path) + hidden_state = data["hidden_state"] + attention_mask = data["attention_mask"] + input_ids = data["input_ids"] + return [hidden_state, input_ids, attention_mask] + + @torch.no_grad() + def cache_batch_outputs( + self, + tokenize_strategy: TokenizeStrategy, + models: List[Any], + text_encoding_strategy: TextEncodingStrategy, + batch: List[train_util.ImageInfo], + ) -> None: + """ + Args: + tokenize_strategy (LuminaTokenizeStrategy): Tokenize strategy + models (List[Any]): Text encoders + text_encoding_strategy (LuminaTextEncodingStrategy): + infos (List): List of ImageInfo + + Returns: + None + """ + assert isinstance(text_encoding_strategy, LuminaTextEncodingStrategy) + assert isinstance(tokenize_strategy, LuminaTokenizeStrategy) + + captions = [info.caption for info in batch] + + if self.is_weighted: + tokens, attention_masks, weights_list = ( + tokenize_strategy.tokenize_with_weights(captions) + ) + hidden_state, input_ids, attention_masks = ( + text_encoding_strategy.encode_tokens_with_weights( + tokenize_strategy, + models, + (tokens, attention_masks), + weights_list, + ) + ) + else: + tokens = tokenize_strategy.tokenize(captions) + hidden_state, input_ids, attention_masks = ( + text_encoding_strategy.encode_tokens( + tokenize_strategy, models, tokens + ) + ) + + if hidden_state.dtype != torch.float32: + hidden_state = hidden_state.float() + + hidden_state = hidden_state.cpu().numpy() + attention_mask = attention_masks.cpu().numpy() # (B, S) + input_ids = input_ids.cpu().numpy() # (B, S) + + + for i, info in enumerate(batch): + hidden_state_i = hidden_state[i] + attention_mask_i = attention_mask[i] + input_ids_i = input_ids[i] + + if self.cache_to_disk: + assert info.text_encoder_outputs_npz is not None, f"Text encoder cache outputs to disk not found for image {info.image_key}" + np.savez( + info.text_encoder_outputs_npz, + hidden_state=hidden_state_i, + attention_mask=attention_mask_i, + input_ids=input_ids_i, + ) + else: + info.text_encoder_outputs = [ + hidden_state_i, + input_ids_i, + attention_mask_i, + ] + + +class LuminaLatentsCachingStrategy(LatentsCachingStrategy): + LUMINA_LATENTS_NPZ_SUFFIX = "_lumina.npz" + + def __init__( + self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool + ) -> None: + super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check) + + @property + def cache_suffix(self) -> str: + return LuminaLatentsCachingStrategy.LUMINA_LATENTS_NPZ_SUFFIX + + def get_latents_npz_path( + self, absolute_path: str, image_size: Tuple[int, int] + ) -> str: + return ( + os.path.splitext(absolute_path)[0] + + f"_{image_size[0]:04d}x{image_size[1]:04d}" + + LuminaLatentsCachingStrategy.LUMINA_LATENTS_NPZ_SUFFIX + ) + + def is_disk_cached_latents_expected( + self, + bucket_reso: Tuple[int, int], + npz_path: str, + flip_aug: bool, + alpha_mask: bool, + ) -> bool: + """ + Args: + bucket_reso (Tuple[int, int]): The resolution of the bucket. + npz_path (str): Path to the npz file. + flip_aug (bool): Whether to flip the image. + alpha_mask (bool): Whether to apply + """ + return self._default_is_disk_cached_latents_expected( + 8, bucket_reso, npz_path, flip_aug, alpha_mask, multi_resolution=True + ) + + def load_latents_from_disk( + self, npz_path: str, bucket_reso: Tuple[int, int] + ) -> Tuple[ + Optional[np.ndarray], + Optional[List[int]], + Optional[List[int]], + Optional[np.ndarray], + Optional[np.ndarray], + ]: + """ + Args: + npz_path (str): Path to the npz file. + bucket_reso (Tuple[int, int]): The resolution of the bucket. + + Returns: + Tuple[ + Optional[np.ndarray], + Optional[List[int]], + Optional[List[int]], + Optional[np.ndarray], + Optional[np.ndarray], + ]: Tuple of latent tensors, attention_mask, input_ids, latents, latents_unet + """ + return self._default_load_latents_from_disk( + 8, npz_path, bucket_reso + ) # support multi-resolution + + # TODO remove circular dependency for ImageInfo + def cache_batch_latents( + self, + model, + batch: List, + flip_aug: bool, + alpha_mask: bool, + random_crop: bool, + ): + encode_by_vae = lambda img_tensor: model.encode(img_tensor).to("cpu") + vae_device = model.device + vae_dtype = model.dtype + + self._default_cache_batch_latents( + encode_by_vae, + vae_device, + vae_dtype, + batch, + flip_aug, + alpha_mask, + random_crop, + multi_resolution=True, + ) + + if not train_util.HIGH_VRAM: + train_util.clean_memory_on_device(model.device) diff --git a/library/strategy_sd.py b/library/strategy_sd.py new file mode 100644 index 000000000..a44fc4092 --- /dev/null +++ b/library/strategy_sd.py @@ -0,0 +1,171 @@ +import glob +import os +from typing import Any, List, Optional, Tuple, Union + +import torch +from transformers import CLIPTokenizer +from library import train_util +from library.strategy_base import LatentsCachingStrategy, TokenizeStrategy, TextEncodingStrategy +from library.utils import setup_logging + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + + +TOKENIZER_ID = "openai/clip-vit-large-patch14" +V2_STABLE_DIFFUSION_ID = "stabilityai/stable-diffusion-2" # ここからtokenizerだけ使う v2とv2.1はtokenizer仕様は同じ + + +class SdTokenizeStrategy(TokenizeStrategy): + def __init__(self, v2: bool, max_length: Optional[int], tokenizer_cache_dir: Optional[str] = None) -> None: + """ + max_length does not include and (None, 75, 150, 225) + """ + logger.info(f"Using {'v2' if v2 else 'v1'} tokenizer") + if v2: + self.tokenizer = self._load_tokenizer( + CLIPTokenizer, V2_STABLE_DIFFUSION_ID, subfolder="tokenizer", tokenizer_cache_dir=tokenizer_cache_dir + ) + else: + self.tokenizer = self._load_tokenizer(CLIPTokenizer, TOKENIZER_ID, tokenizer_cache_dir=tokenizer_cache_dir) + + if max_length is None: + self.max_length = self.tokenizer.model_max_length + else: + self.max_length = max_length + 2 + + def tokenize(self, text: Union[str, List[str]]) -> List[torch.Tensor]: + text = [text] if isinstance(text, str) else text + return [torch.stack([self._get_input_ids(self.tokenizer, t, self.max_length) for t in text], dim=0)] + + def tokenize_with_weights(self, text: str | List[str]) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: + text = [text] if isinstance(text, str) else text + tokens_list = [] + weights_list = [] + for t in text: + tokens, weights = self._get_input_ids(self.tokenizer, t, self.max_length, weighted=True) + tokens_list.append(tokens) + weights_list.append(weights) + return [torch.stack(tokens_list, dim=0)], [torch.stack(weights_list, dim=0)] + + +class SdTextEncodingStrategy(TextEncodingStrategy): + def __init__(self, clip_skip: Optional[int] = None) -> None: + self.clip_skip = clip_skip + + def encode_tokens( + self, tokenize_strategy: TokenizeStrategy, models: List[Any], tokens: List[torch.Tensor] + ) -> List[torch.Tensor]: + text_encoder = models[0] + tokens = tokens[0] + sd_tokenize_strategy = tokenize_strategy # type: SdTokenizeStrategy + + # tokens: b,n,77 + b_size = tokens.size()[0] + max_token_length = tokens.size()[1] * tokens.size()[2] + model_max_length = sd_tokenize_strategy.tokenizer.model_max_length + tokens = tokens.reshape((-1, model_max_length)) # batch_size*3, 77 + + tokens = tokens.to(text_encoder.device) + + if self.clip_skip is None: + encoder_hidden_states = text_encoder(tokens)[0] + else: + enc_out = text_encoder(tokens, output_hidden_states=True, return_dict=True) + encoder_hidden_states = enc_out["hidden_states"][-self.clip_skip] + encoder_hidden_states = text_encoder.text_model.final_layer_norm(encoder_hidden_states) + + # bs*3, 77, 768 or 1024 + encoder_hidden_states = encoder_hidden_states.reshape((b_size, -1, encoder_hidden_states.shape[-1])) + + if max_token_length != model_max_length: + v1 = sd_tokenize_strategy.tokenizer.pad_token_id == sd_tokenize_strategy.tokenizer.eos_token_id + if not v1: + # v2: ... ... の三連を ... ... へ戻す 正直この実装でいいのかわからん + states_list = [encoder_hidden_states[:, 0].unsqueeze(1)] # + for i in range(1, max_token_length, model_max_length): + chunk = encoder_hidden_states[:, i : i + model_max_length - 2] # の後から 最後の前まで + if i > 0: + for j in range(len(chunk)): + if tokens[j, 1] == sd_tokenize_strategy.tokenizer.eos_token: + # 空、つまり ...のパターン + chunk[j, 0] = chunk[j, 1] # 次の の値をコピーする + states_list.append(chunk) # の後から の前まで + states_list.append(encoder_hidden_states[:, -1].unsqueeze(1)) # のどちらか + encoder_hidden_states = torch.cat(states_list, dim=1) + else: + # v1: ... の三連を ... へ戻す + states_list = [encoder_hidden_states[:, 0].unsqueeze(1)] # + for i in range(1, max_token_length, model_max_length): + states_list.append(encoder_hidden_states[:, i : i + model_max_length - 2]) # の後から の前まで + states_list.append(encoder_hidden_states[:, -1].unsqueeze(1)) # + encoder_hidden_states = torch.cat(states_list, dim=1) + + return [encoder_hidden_states] + + def encode_tokens_with_weights( + self, + tokenize_strategy: TokenizeStrategy, + models: List[Any], + tokens_list: List[torch.Tensor], + weights_list: List[torch.Tensor], + ) -> List[torch.Tensor]: + encoder_hidden_states = self.encode_tokens(tokenize_strategy, models, tokens_list)[0] + + weights = weights_list[0].to(encoder_hidden_states.device) + + # apply weights + if weights.shape[1] == 1: # no max_token_length + # weights: ((b, 1, 77), (b, 1, 77)), hidden_states: (b, 77, 768), (b, 77, 768) + encoder_hidden_states = encoder_hidden_states * weights.squeeze(1).unsqueeze(2) + else: + # weights: ((b, n, 77), (b, n, 77)), hidden_states: (b, n*75+2, 768), (b, n*75+2, 768) + for i in range(weights.shape[1]): + encoder_hidden_states[:, i * 75 + 1 : i * 75 + 76] = encoder_hidden_states[:, i * 75 + 1 : i * 75 + 76] * weights[ + :, i, 1:-1 + ].unsqueeze(-1) + + return [encoder_hidden_states] + + +class SdSdxlLatentsCachingStrategy(LatentsCachingStrategy): + # sd and sdxl share the same strategy. we can make them separate, but the difference is only the suffix. + # and we keep the old npz for the backward compatibility. + + SD_OLD_LATENTS_NPZ_SUFFIX = ".npz" + SD_LATENTS_NPZ_SUFFIX = "_sd.npz" + SDXL_LATENTS_NPZ_SUFFIX = "_sdxl.npz" + + def __init__(self, sd: bool, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None: + super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check) + self.sd = sd + self.suffix = ( + SdSdxlLatentsCachingStrategy.SD_LATENTS_NPZ_SUFFIX if sd else SdSdxlLatentsCachingStrategy.SDXL_LATENTS_NPZ_SUFFIX + ) + + @property + def cache_suffix(self) -> str: + return self.suffix + + def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) -> str: + # support old .npz + old_npz_file = os.path.splitext(absolute_path)[0] + SdSdxlLatentsCachingStrategy.SD_OLD_LATENTS_NPZ_SUFFIX + if os.path.exists(old_npz_file): + return old_npz_file + return os.path.splitext(absolute_path)[0] + f"_{image_size[0]:04d}x{image_size[1]:04d}" + self.suffix + + def is_disk_cached_latents_expected(self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool): + return self._default_is_disk_cached_latents_expected(8, bucket_reso, npz_path, flip_aug, alpha_mask) + + # TODO remove circular dependency for ImageInfo + def cache_batch_latents(self, vae, image_infos: List, flip_aug: bool, alpha_mask: bool, random_crop: bool): + encode_by_vae = lambda img_tensor: vae.encode(img_tensor).latent_dist.sample() + vae_device = vae.device + vae_dtype = vae.dtype + + self._default_cache_batch_latents(encode_by_vae, vae_device, vae_dtype, image_infos, flip_aug, alpha_mask, random_crop) + + if not train_util.HIGH_VRAM: + train_util.clean_memory_on_device(vae.device) diff --git a/library/strategy_sd3.py b/library/strategy_sd3.py new file mode 100644 index 000000000..1d55fe21d --- /dev/null +++ b/library/strategy_sd3.py @@ -0,0 +1,420 @@ +import os +import glob +import random +from typing import Any, List, Optional, Tuple, Union +import torch +import numpy as np +from transformers import CLIPTokenizer, T5TokenizerFast, CLIPTextModel, CLIPTextModelWithProjection, T5EncoderModel + +from library import sd3_utils, train_util +from library import sd3_models +from library.strategy_base import LatentsCachingStrategy, TextEncodingStrategy, TokenizeStrategy, TextEncoderOutputsCachingStrategy + +from library.utils import setup_logging + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + + +CLIP_L_TOKENIZER_ID = "openai/clip-vit-large-patch14" +CLIP_G_TOKENIZER_ID = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k" +T5_XXL_TOKENIZER_ID = "google/t5-v1_1-xxl" + + +class Sd3TokenizeStrategy(TokenizeStrategy): + def __init__(self, t5xxl_max_length: int = 256, tokenizer_cache_dir: Optional[str] = None) -> None: + self.t5xxl_max_length = t5xxl_max_length + self.clip_l = self._load_tokenizer(CLIPTokenizer, CLIP_L_TOKENIZER_ID, tokenizer_cache_dir=tokenizer_cache_dir) + self.clip_g = self._load_tokenizer(CLIPTokenizer, CLIP_G_TOKENIZER_ID, tokenizer_cache_dir=tokenizer_cache_dir) + self.t5xxl = self._load_tokenizer(T5TokenizerFast, T5_XXL_TOKENIZER_ID, tokenizer_cache_dir=tokenizer_cache_dir) + self.clip_g.pad_token_id = 0 # use 0 as pad token for clip_g + + def tokenize(self, text: Union[str, List[str]]) -> List[torch.Tensor]: + text = [text] if isinstance(text, str) else text + + l_tokens = self.clip_l(text, max_length=77, padding="max_length", truncation=True, return_tensors="pt") + g_tokens = self.clip_g(text, max_length=77, padding="max_length", truncation=True, return_tensors="pt") + t5_tokens = self.t5xxl(text, max_length=self.t5xxl_max_length, padding="max_length", truncation=True, return_tensors="pt") + + l_attn_mask = l_tokens["attention_mask"] + g_attn_mask = g_tokens["attention_mask"] + t5_attn_mask = t5_tokens["attention_mask"] + l_tokens = l_tokens["input_ids"] + g_tokens = g_tokens["input_ids"] + t5_tokens = t5_tokens["input_ids"] + + return [l_tokens, g_tokens, t5_tokens, l_attn_mask, g_attn_mask, t5_attn_mask] + + +class Sd3TextEncodingStrategy(TextEncodingStrategy): + def __init__( + self, + apply_lg_attn_mask: Optional[bool] = None, + apply_t5_attn_mask: Optional[bool] = None, + l_dropout_rate: float = 0.0, + g_dropout_rate: float = 0.0, + t5_dropout_rate: float = 0.0, + ) -> None: + """ + Args: + apply_t5_attn_mask: Default value for apply_t5_attn_mask. + """ + self.apply_lg_attn_mask = apply_lg_attn_mask + self.apply_t5_attn_mask = apply_t5_attn_mask + self.l_dropout_rate = l_dropout_rate + self.g_dropout_rate = g_dropout_rate + self.t5_dropout_rate = t5_dropout_rate + + def encode_tokens( + self, + tokenize_strategy: TokenizeStrategy, + models: List[Any], + tokens: List[torch.Tensor], + apply_lg_attn_mask: Optional[bool] = False, + apply_t5_attn_mask: Optional[bool] = False, + enable_dropout: bool = True, + ) -> List[torch.Tensor]: + """ + returned embeddings are not masked + """ + clip_l, clip_g, t5xxl = models + clip_l: Optional[CLIPTextModel] + clip_g: Optional[CLIPTextModelWithProjection] + t5xxl: Optional[T5EncoderModel] + + if apply_lg_attn_mask is None: + apply_lg_attn_mask = self.apply_lg_attn_mask + if apply_t5_attn_mask is None: + apply_t5_attn_mask = self.apply_t5_attn_mask + + l_tokens, g_tokens, t5_tokens, l_attn_mask, g_attn_mask, t5_attn_mask = tokens + + # dropout: if enable_dropout is False, dropout is not applied. dropout means zeroing out embeddings + + if l_tokens is None or clip_l is None: + assert g_tokens is None, "g_tokens must be None if l_tokens is None" + lg_out = None + lg_pooled = None + l_attn_mask = None + g_attn_mask = None + else: + assert g_tokens is not None, "g_tokens must not be None if l_tokens is not None" + + # drop some members of the batch: we do not call clip_l and clip_g for dropped members + batch_size, l_seq_len = l_tokens.shape + g_seq_len = g_tokens.shape[1] + + non_drop_l_indices = [] + non_drop_g_indices = [] + for i in range(l_tokens.shape[0]): + drop_l = enable_dropout and (self.l_dropout_rate > 0.0 and random.random() < self.l_dropout_rate) + drop_g = enable_dropout and (self.g_dropout_rate > 0.0 and random.random() < self.g_dropout_rate) + if not drop_l: + non_drop_l_indices.append(i) + if not drop_g: + non_drop_g_indices.append(i) + + # filter out dropped members + if len(non_drop_l_indices) > 0 and len(non_drop_l_indices) < batch_size: + l_tokens = l_tokens[non_drop_l_indices] + l_attn_mask = l_attn_mask[non_drop_l_indices] + if len(non_drop_g_indices) > 0 and len(non_drop_g_indices) < batch_size: + g_tokens = g_tokens[non_drop_g_indices] + g_attn_mask = g_attn_mask[non_drop_g_indices] + + # call clip_l for non-dropped members + if len(non_drop_l_indices) > 0: + nd_l_attn_mask = l_attn_mask.to(clip_l.device) + prompt_embeds = clip_l( + l_tokens.to(clip_l.device), nd_l_attn_mask if apply_lg_attn_mask else None, output_hidden_states=True + ) + nd_l_pooled = prompt_embeds[0] + nd_l_out = prompt_embeds.hidden_states[-2] + if len(non_drop_g_indices) > 0: + nd_g_attn_mask = g_attn_mask.to(clip_g.device) + prompt_embeds = clip_g( + g_tokens.to(clip_g.device), nd_g_attn_mask if apply_lg_attn_mask else None, output_hidden_states=True + ) + nd_g_pooled = prompt_embeds[0] + nd_g_out = prompt_embeds.hidden_states[-2] + + # fill in the dropped members + if len(non_drop_l_indices) == batch_size: + l_pooled = nd_l_pooled + l_out = nd_l_out + else: + # model output is always float32 because of the models are wrapped with Accelerator + l_pooled = torch.zeros((batch_size, 768), device=clip_l.device, dtype=torch.float32) + l_out = torch.zeros((batch_size, l_seq_len, 768), device=clip_l.device, dtype=torch.float32) + l_attn_mask = torch.zeros((batch_size, l_seq_len), device=clip_l.device, dtype=l_attn_mask.dtype) + if len(non_drop_l_indices) > 0: + l_pooled[non_drop_l_indices] = nd_l_pooled + l_out[non_drop_l_indices] = nd_l_out + l_attn_mask[non_drop_l_indices] = nd_l_attn_mask + + if len(non_drop_g_indices) == batch_size: + g_pooled = nd_g_pooled + g_out = nd_g_out + else: + g_pooled = torch.zeros((batch_size, 1280), device=clip_g.device, dtype=torch.float32) + g_out = torch.zeros((batch_size, g_seq_len, 1280), device=clip_g.device, dtype=torch.float32) + g_attn_mask = torch.zeros((batch_size, g_seq_len), device=clip_g.device, dtype=g_attn_mask.dtype) + if len(non_drop_g_indices) > 0: + g_pooled[non_drop_g_indices] = nd_g_pooled + g_out[non_drop_g_indices] = nd_g_out + g_attn_mask[non_drop_g_indices] = nd_g_attn_mask + + lg_pooled = torch.cat((l_pooled, g_pooled), dim=-1) + lg_out = torch.cat([l_out, g_out], dim=-1) + + if t5xxl is None or t5_tokens is None: + t5_out = None + t5_attn_mask = None + else: + # drop some members of the batch: we do not call t5xxl for dropped members + batch_size, t5_seq_len = t5_tokens.shape + non_drop_t5_indices = [] + for i in range(t5_tokens.shape[0]): + drop_t5 = enable_dropout and (self.t5_dropout_rate > 0.0 and random.random() < self.t5_dropout_rate) + if not drop_t5: + non_drop_t5_indices.append(i) + + # filter out dropped members + if len(non_drop_t5_indices) > 0 and len(non_drop_t5_indices) < batch_size: + t5_tokens = t5_tokens[non_drop_t5_indices] + t5_attn_mask = t5_attn_mask[non_drop_t5_indices] + + # call t5xxl for non-dropped members + if len(non_drop_t5_indices) > 0: + nd_t5_attn_mask = t5_attn_mask.to(t5xxl.device) + nd_t5_out, _ = t5xxl( + t5_tokens.to(t5xxl.device), + nd_t5_attn_mask if apply_t5_attn_mask else None, + return_dict=False, + output_hidden_states=True, + ) + + # fill in the dropped members + if len(non_drop_t5_indices) == batch_size: + t5_out = nd_t5_out + else: + t5_out = torch.zeros((batch_size, t5_seq_len, 4096), device=t5xxl.device, dtype=torch.float32) + t5_attn_mask = torch.zeros((batch_size, t5_seq_len), device=t5xxl.device, dtype=t5_attn_mask.dtype) + if len(non_drop_t5_indices) > 0: + t5_out[non_drop_t5_indices] = nd_t5_out + t5_attn_mask[non_drop_t5_indices] = nd_t5_attn_mask + + # masks are used for attention masking in transformer + return [lg_out, t5_out, lg_pooled, l_attn_mask, g_attn_mask, t5_attn_mask] + + def drop_cached_text_encoder_outputs( + self, + lg_out: torch.Tensor, + t5_out: torch.Tensor, + lg_pooled: torch.Tensor, + l_attn_mask: torch.Tensor, + g_attn_mask: torch.Tensor, + t5_attn_mask: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + # dropout: if enable_dropout is True, dropout is not applied. dropout means zeroing out embeddings + if lg_out is not None: + for i in range(lg_out.shape[0]): + drop_l = self.l_dropout_rate > 0.0 and random.random() < self.l_dropout_rate + if drop_l: + lg_out[i, :, :768] = torch.zeros_like(lg_out[i, :, :768]) + lg_pooled[i, :768] = torch.zeros_like(lg_pooled[i, :768]) + if l_attn_mask is not None: + l_attn_mask[i] = torch.zeros_like(l_attn_mask[i]) + drop_g = self.g_dropout_rate > 0.0 and random.random() < self.g_dropout_rate + if drop_g: + lg_out[i, :, 768:] = torch.zeros_like(lg_out[i, :, 768:]) + lg_pooled[i, 768:] = torch.zeros_like(lg_pooled[i, 768:]) + if g_attn_mask is not None: + g_attn_mask[i] = torch.zeros_like(g_attn_mask[i]) + + if t5_out is not None: + for i in range(t5_out.shape[0]): + drop_t5 = self.t5_dropout_rate > 0.0 and random.random() < self.t5_dropout_rate + if drop_t5: + t5_out[i] = torch.zeros_like(t5_out[i]) + if t5_attn_mask is not None: + t5_attn_mask[i] = torch.zeros_like(t5_attn_mask[i]) + + return [lg_out, t5_out, lg_pooled, l_attn_mask, g_attn_mask, t5_attn_mask] + + def concat_encodings( + self, lg_out: torch.Tensor, t5_out: Optional[torch.Tensor], lg_pooled: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + lg_out = torch.nn.functional.pad(lg_out, (0, 4096 - lg_out.shape[-1])) + if t5_out is None: + t5_out = torch.zeros((lg_out.shape[0], 77, 4096), device=lg_out.device, dtype=lg_out.dtype) + return torch.cat([lg_out, t5_out], dim=-2), lg_pooled + + +class Sd3TextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy): + SD3_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX = "_sd3_te.npz" + + def __init__( + self, + cache_to_disk: bool, + batch_size: int, + skip_disk_cache_validity_check: bool, + is_partial: bool = False, + apply_lg_attn_mask: bool = False, + apply_t5_attn_mask: bool = False, + ) -> None: + super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check, is_partial) + self.apply_lg_attn_mask = apply_lg_attn_mask + self.apply_t5_attn_mask = apply_t5_attn_mask + + def get_outputs_npz_path(self, image_abs_path: str) -> str: + return os.path.splitext(image_abs_path)[0] + Sd3TextEncoderOutputsCachingStrategy.SD3_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX + + def is_disk_cached_outputs_expected(self, npz_path: str): + if not self.cache_to_disk: + return False + if not os.path.exists(npz_path): + return False + if self.skip_disk_cache_validity_check: + return True + + try: + npz = np.load(npz_path) + if "lg_out" not in npz: + return False + if "lg_pooled" not in npz: + return False + if "clip_l_attn_mask" not in npz or "clip_g_attn_mask" not in npz: # necessary even if not used + return False + if "apply_lg_attn_mask" not in npz: + return False + if "t5_out" not in npz: + return False + if "t5_attn_mask" not in npz: + return False + npz_apply_lg_attn_mask = npz["apply_lg_attn_mask"] + if npz_apply_lg_attn_mask != self.apply_lg_attn_mask: + return False + if "apply_t5_attn_mask" not in npz: + return False + npz_apply_t5_attn_mask = npz["apply_t5_attn_mask"] + if npz_apply_t5_attn_mask != self.apply_t5_attn_mask: + return False + except Exception as e: + logger.error(f"Error loading file: {npz_path}") + raise e + + return True + + def load_outputs_npz(self, npz_path: str) -> List[np.ndarray]: + data = np.load(npz_path) + lg_out = data["lg_out"] + lg_pooled = data["lg_pooled"] + t5_out = data["t5_out"] + + l_attn_mask = data["clip_l_attn_mask"] + g_attn_mask = data["clip_g_attn_mask"] + t5_attn_mask = data["t5_attn_mask"] + + # apply_t5_attn_mask and apply_lg_attn_mask are same as self.apply_t5_attn_mask and self.apply_lg_attn_mask + return [lg_out, t5_out, lg_pooled, l_attn_mask, g_attn_mask, t5_attn_mask] + + def cache_batch_outputs( + self, tokenize_strategy: TokenizeStrategy, models: List[Any], text_encoding_strategy: TextEncodingStrategy, infos: List + ): + sd3_text_encoding_strategy: Sd3TextEncodingStrategy = text_encoding_strategy + captions = [info.caption for info in infos] + + tokens_and_masks = tokenize_strategy.tokenize(captions) + with torch.no_grad(): + # always disable dropout during caching + lg_out, t5_out, lg_pooled, l_attn_mask, g_attn_mask, t5_attn_mask = sd3_text_encoding_strategy.encode_tokens( + tokenize_strategy, + models, + tokens_and_masks, + apply_lg_attn_mask=self.apply_lg_attn_mask, + apply_t5_attn_mask=self.apply_t5_attn_mask, + enable_dropout=False, + ) + + if lg_out.dtype == torch.bfloat16: + lg_out = lg_out.float() + if lg_pooled.dtype == torch.bfloat16: + lg_pooled = lg_pooled.float() + if t5_out.dtype == torch.bfloat16: + t5_out = t5_out.float() + + lg_out = lg_out.cpu().numpy() + lg_pooled = lg_pooled.cpu().numpy() + t5_out = t5_out.cpu().numpy() + + l_attn_mask = tokens_and_masks[3].cpu().numpy() + g_attn_mask = tokens_and_masks[4].cpu().numpy() + t5_attn_mask = tokens_and_masks[5].cpu().numpy() + + for i, info in enumerate(infos): + lg_out_i = lg_out[i] + t5_out_i = t5_out[i] + lg_pooled_i = lg_pooled[i] + l_attn_mask_i = l_attn_mask[i] + g_attn_mask_i = g_attn_mask[i] + t5_attn_mask_i = t5_attn_mask[i] + apply_lg_attn_mask = self.apply_lg_attn_mask + apply_t5_attn_mask = self.apply_t5_attn_mask + + if self.cache_to_disk: + np.savez( + info.text_encoder_outputs_npz, + lg_out=lg_out_i, + lg_pooled=lg_pooled_i, + t5_out=t5_out_i, + clip_l_attn_mask=l_attn_mask_i, + clip_g_attn_mask=g_attn_mask_i, + t5_attn_mask=t5_attn_mask_i, + apply_lg_attn_mask=apply_lg_attn_mask, + apply_t5_attn_mask=apply_t5_attn_mask, + ) + else: + # it's fine that attn mask is not None. it's overwritten before calling the model if necessary + info.text_encoder_outputs = (lg_out_i, t5_out_i, lg_pooled_i, l_attn_mask_i, g_attn_mask_i, t5_attn_mask_i) + + +class Sd3LatentsCachingStrategy(LatentsCachingStrategy): + SD3_LATENTS_NPZ_SUFFIX = "_sd3.npz" + + def __init__(self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None: + super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check) + + @property + def cache_suffix(self) -> str: + return Sd3LatentsCachingStrategy.SD3_LATENTS_NPZ_SUFFIX + + def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) -> str: + return ( + os.path.splitext(absolute_path)[0] + + f"_{image_size[0]:04d}x{image_size[1]:04d}" + + Sd3LatentsCachingStrategy.SD3_LATENTS_NPZ_SUFFIX + ) + + def is_disk_cached_latents_expected(self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool): + return self._default_is_disk_cached_latents_expected(8, bucket_reso, npz_path, flip_aug, alpha_mask, multi_resolution=True) + + def load_latents_from_disk( + self, npz_path: str, bucket_reso: Tuple[int, int] + ) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]: + return self._default_load_latents_from_disk(8, npz_path, bucket_reso) # support multi-resolution + + # TODO remove circular dependency for ImageInfo + def cache_batch_latents(self, vae, image_infos: List, flip_aug: bool, alpha_mask: bool, random_crop: bool): + encode_by_vae = lambda img_tensor: vae.encode(img_tensor).to("cpu") + vae_device = vae.device + vae_dtype = vae.dtype + + self._default_cache_batch_latents( + encode_by_vae, vae_device, vae_dtype, image_infos, flip_aug, alpha_mask, random_crop, multi_resolution=True + ) + + if not train_util.HIGH_VRAM: + train_util.clean_memory_on_device(vae.device) diff --git a/library/strategy_sdxl.py b/library/strategy_sdxl.py new file mode 100644 index 000000000..6b3e2afa6 --- /dev/null +++ b/library/strategy_sdxl.py @@ -0,0 +1,306 @@ +import os +from typing import Any, List, Optional, Tuple, Union + +import numpy as np +import torch +from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextModelWithProjection +from library.strategy_base import TokenizeStrategy, TextEncodingStrategy, TextEncoderOutputsCachingStrategy + + +from library.utils import setup_logging + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + + +TOKENIZER1_PATH = "openai/clip-vit-large-patch14" +TOKENIZER2_PATH = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k" + + +class SdxlTokenizeStrategy(TokenizeStrategy): + def __init__(self, max_length: Optional[int], tokenizer_cache_dir: Optional[str] = None) -> None: + self.tokenizer1 = self._load_tokenizer(CLIPTokenizer, TOKENIZER1_PATH, tokenizer_cache_dir=tokenizer_cache_dir) + self.tokenizer2 = self._load_tokenizer(CLIPTokenizer, TOKENIZER2_PATH, tokenizer_cache_dir=tokenizer_cache_dir) + self.tokenizer2.pad_token_id = 0 # use 0 as pad token for tokenizer2 + + if max_length is None: + self.max_length = self.tokenizer1.model_max_length + else: + self.max_length = max_length + 2 + + def tokenize(self, text: Union[str, List[str]]) -> List[torch.Tensor]: + text = [text] if isinstance(text, str) else text + return ( + torch.stack([self._get_input_ids(self.tokenizer1, t, self.max_length) for t in text], dim=0), + torch.stack([self._get_input_ids(self.tokenizer2, t, self.max_length) for t in text], dim=0), + ) + + def tokenize_with_weights(self, text: str | List[str]) -> Tuple[List[torch.Tensor]]: + text = [text] if isinstance(text, str) else text + tokens1_list, tokens2_list = [], [] + weights1_list, weights2_list = [], [] + for t in text: + tokens1, weights1 = self._get_input_ids(self.tokenizer1, t, self.max_length, weighted=True) + tokens2, weights2 = self._get_input_ids(self.tokenizer2, t, self.max_length, weighted=True) + tokens1_list.append(tokens1) + tokens2_list.append(tokens2) + weights1_list.append(weights1) + weights2_list.append(weights2) + return [torch.stack(tokens1_list, dim=0), torch.stack(tokens2_list, dim=0)], [ + torch.stack(weights1_list, dim=0), + torch.stack(weights2_list, dim=0), + ] + + +class SdxlTextEncodingStrategy(TextEncodingStrategy): + def __init__(self) -> None: + pass + + def _pool_workaround( + self, text_encoder: CLIPTextModelWithProjection, last_hidden_state: torch.Tensor, input_ids: torch.Tensor, eos_token_id: int + ): + r""" + workaround for CLIP's pooling bug: it returns the hidden states for the max token id as the pooled output + instead of the hidden states for the EOS token + If we use Textual Inversion, we need to use the hidden states for the EOS token as the pooled output + + Original code from CLIP's pooling function: + + \# text_embeds.shape = [batch_size, sequence_length, transformer.width] + \# take features from the eot embedding (eot_token is the highest number in each sequence) + \# casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14 + pooled_output = last_hidden_state[ + torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device), + input_ids.to(dtype=torch.int, device=last_hidden_state.device).argmax(dim=-1), + ] + """ + + # input_ids: b*n,77 + # find index for EOS token + + # Following code is not working if one of the input_ids has multiple EOS tokens (very odd case) + # eos_token_index = torch.where(input_ids == eos_token_id)[1] + # eos_token_index = eos_token_index.to(device=last_hidden_state.device) + + # Create a mask where the EOS tokens are + eos_token_mask = (input_ids == eos_token_id).int() + + # Use argmax to find the last index of the EOS token for each element in the batch + eos_token_index = torch.argmax(eos_token_mask, dim=1) # this will be 0 if there is no EOS token, it's fine + eos_token_index = eos_token_index.to(device=last_hidden_state.device) + + # get hidden states for EOS token + pooled_output = last_hidden_state[ + torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device), eos_token_index + ] + + # apply projection: projection may be of different dtype than last_hidden_state + pooled_output = text_encoder.text_projection(pooled_output.to(text_encoder.text_projection.weight.dtype)) + pooled_output = pooled_output.to(last_hidden_state.dtype) + + return pooled_output + + def _get_hidden_states_sdxl( + self, + input_ids1: torch.Tensor, + input_ids2: torch.Tensor, + tokenizer1: CLIPTokenizer, + tokenizer2: CLIPTokenizer, + text_encoder1: Union[CLIPTextModel, torch.nn.Module], + text_encoder2: Union[CLIPTextModelWithProjection, torch.nn.Module], + unwrapped_text_encoder2: Optional[CLIPTextModelWithProjection] = None, + ): + # input_ids: b,n,77 -> b*n, 77 + b_size = input_ids1.size()[0] + if input_ids1.size()[1] == 1: + max_token_length = None + else: + max_token_length = input_ids1.size()[1] * input_ids1.size()[2] + input_ids1 = input_ids1.reshape((-1, tokenizer1.model_max_length)) # batch_size*n, 77 + input_ids2 = input_ids2.reshape((-1, tokenizer2.model_max_length)) # batch_size*n, 77 + input_ids1 = input_ids1.to(text_encoder1.device) + input_ids2 = input_ids2.to(text_encoder2.device) + + # text_encoder1 + enc_out = text_encoder1(input_ids1, output_hidden_states=True, return_dict=True) + hidden_states1 = enc_out["hidden_states"][11] + + # text_encoder2 + enc_out = text_encoder2(input_ids2, output_hidden_states=True, return_dict=True) + hidden_states2 = enc_out["hidden_states"][-2] # penuultimate layer + + # pool2 = enc_out["text_embeds"] + unwrapped_text_encoder2 = unwrapped_text_encoder2 or text_encoder2 + pool2 = self._pool_workaround(unwrapped_text_encoder2, enc_out["last_hidden_state"], input_ids2, tokenizer2.eos_token_id) + + # b*n, 77, 768 or 1280 -> b, n*77, 768 or 1280 + n_size = 1 if max_token_length is None else max_token_length // 75 + hidden_states1 = hidden_states1.reshape((b_size, -1, hidden_states1.shape[-1])) + hidden_states2 = hidden_states2.reshape((b_size, -1, hidden_states2.shape[-1])) + + if max_token_length is not None: + # bs*3, 77, 768 or 1024 + # encoder1: ... の三連を ... へ戻す + states_list = [hidden_states1[:, 0].unsqueeze(1)] # + for i in range(1, max_token_length, tokenizer1.model_max_length): + states_list.append(hidden_states1[:, i : i + tokenizer1.model_max_length - 2]) # の後から の前まで + states_list.append(hidden_states1[:, -1].unsqueeze(1)) # + hidden_states1 = torch.cat(states_list, dim=1) + + # v2: ... ... の三連を ... ... へ戻す 正直この実装でいいのかわからん + states_list = [hidden_states2[:, 0].unsqueeze(1)] # + for i in range(1, max_token_length, tokenizer2.model_max_length): + chunk = hidden_states2[:, i : i + tokenizer2.model_max_length - 2] # の後から 最後の前まで + # this causes an error: + # RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation + # if i > 1: + # for j in range(len(chunk)): # batch_size + # if input_ids2[n_index + j * n_size, 1] == tokenizer2.eos_token_id: # 空、つまり ...のパターン + # chunk[j, 0] = chunk[j, 1] # 次の の値をコピーする + states_list.append(chunk) # の後から の前まで + states_list.append(hidden_states2[:, -1].unsqueeze(1)) # のどちらか + hidden_states2 = torch.cat(states_list, dim=1) + + # pool はnの最初のものを使う + pool2 = pool2[::n_size] + + return hidden_states1, hidden_states2, pool2 + + def encode_tokens( + self, tokenize_strategy: TokenizeStrategy, models: List[Any], tokens: List[torch.Tensor] + ) -> List[torch.Tensor]: + """ + Args: + tokenize_strategy: TokenizeStrategy + models: List of models, [text_encoder1, text_encoder2, unwrapped text_encoder2 (optional)]. + If text_encoder2 is wrapped by accelerate, unwrapped_text_encoder2 is required + tokens: List of tokens, for text_encoder1 and text_encoder2 + """ + if len(models) == 2: + text_encoder1, text_encoder2 = models + unwrapped_text_encoder2 = None + else: + text_encoder1, text_encoder2, unwrapped_text_encoder2 = models + tokens1, tokens2 = tokens + sdxl_tokenize_strategy = tokenize_strategy # type: SdxlTokenizeStrategy + tokenizer1, tokenizer2 = sdxl_tokenize_strategy.tokenizer1, sdxl_tokenize_strategy.tokenizer2 + + hidden_states1, hidden_states2, pool2 = self._get_hidden_states_sdxl( + tokens1, tokens2, tokenizer1, tokenizer2, text_encoder1, text_encoder2, unwrapped_text_encoder2 + ) + return [hidden_states1, hidden_states2, pool2] + + def encode_tokens_with_weights( + self, + tokenize_strategy: TokenizeStrategy, + models: List[Any], + tokens_list: List[torch.Tensor], + weights_list: List[torch.Tensor], + ) -> List[torch.Tensor]: + hidden_states1, hidden_states2, pool2 = self.encode_tokens(tokenize_strategy, models, tokens_list) + + weights_list = [weights.to(hidden_states1.device) for weights in weights_list] + + # apply weights + if weights_list[0].shape[1] == 1: # no max_token_length + # weights: ((b, 1, 77), (b, 1, 77)), hidden_states: (b, 77, 768), (b, 77, 768) + hidden_states1 = hidden_states1 * weights_list[0].squeeze(1).unsqueeze(2) + hidden_states2 = hidden_states2 * weights_list[1].squeeze(1).unsqueeze(2) + else: + # weights: ((b, n, 77), (b, n, 77)), hidden_states: (b, n*75+2, 768), (b, n*75+2, 768) + for weight, hidden_states in zip(weights_list, [hidden_states1, hidden_states2]): + for i in range(weight.shape[1]): + hidden_states[:, i * 75 + 1 : i * 75 + 76] = hidden_states[:, i * 75 + 1 : i * 75 + 76] * weight[ + :, i, 1:-1 + ].unsqueeze(-1) + + return [hidden_states1, hidden_states2, pool2] + + +class SdxlTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy): + SDXL_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX = "_te_outputs.npz" + + def __init__( + self, + cache_to_disk: bool, + batch_size: int, + skip_disk_cache_validity_check: bool, + is_partial: bool = False, + is_weighted: bool = False, + ) -> None: + super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check, is_partial, is_weighted) + + def get_outputs_npz_path(self, image_abs_path: str) -> str: + return os.path.splitext(image_abs_path)[0] + SdxlTextEncoderOutputsCachingStrategy.SDXL_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX + + def is_disk_cached_outputs_expected(self, npz_path: str): + if not self.cache_to_disk: + return False + if not os.path.exists(npz_path): + return False + if self.skip_disk_cache_validity_check: + return True + + try: + npz = np.load(npz_path) + if "hidden_state1" not in npz or "hidden_state2" not in npz or "pool2" not in npz: + return False + except Exception as e: + logger.error(f"Error loading file: {npz_path}") + raise e + + return True + + def load_outputs_npz(self, npz_path: str) -> List[np.ndarray]: + data = np.load(npz_path) + hidden_state1 = data["hidden_state1"] + hidden_state2 = data["hidden_state2"] + pool2 = data["pool2"] + return [hidden_state1, hidden_state2, pool2] + + def cache_batch_outputs( + self, tokenize_strategy: TokenizeStrategy, models: List[Any], text_encoding_strategy: TextEncodingStrategy, infos: List + ): + sdxl_text_encoding_strategy = text_encoding_strategy # type: SdxlTextEncodingStrategy + captions = [info.caption for info in infos] + + if self.is_weighted: + tokens_list, weights_list = tokenize_strategy.tokenize_with_weights(captions) + with torch.no_grad(): + hidden_state1, hidden_state2, pool2 = sdxl_text_encoding_strategy.encode_tokens_with_weights( + tokenize_strategy, models, tokens_list, weights_list + ) + else: + tokens1, tokens2 = tokenize_strategy.tokenize(captions) + with torch.no_grad(): + hidden_state1, hidden_state2, pool2 = sdxl_text_encoding_strategy.encode_tokens( + tokenize_strategy, models, [tokens1, tokens2] + ) + + if hidden_state1.dtype == torch.bfloat16: + hidden_state1 = hidden_state1.float() + if hidden_state2.dtype == torch.bfloat16: + hidden_state2 = hidden_state2.float() + if pool2.dtype == torch.bfloat16: + pool2 = pool2.float() + + hidden_state1 = hidden_state1.cpu().numpy() + hidden_state2 = hidden_state2.cpu().numpy() + pool2 = pool2.cpu().numpy() + + for i, info in enumerate(infos): + hidden_state1_i = hidden_state1[i] + hidden_state2_i = hidden_state2[i] + pool2_i = pool2[i] + + if self.cache_to_disk: + np.savez( + info.text_encoder_outputs_npz, + hidden_state1=hidden_state1_i, + hidden_state2=hidden_state2_i, + pool2=pool2_i, + ) + else: + info.text_encoder_outputs = [hidden_state1_i, hidden_state2_i, pool2_i] diff --git a/library/train_util.py b/library/train_util.py index fd46f905b..756d88b1c 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -3,6 +3,7 @@ import argparse import ast import asyncio +from concurrent.futures import Future, ThreadPoolExecutor import datetime import importlib import json @@ -11,15 +12,8 @@ import re import shutil import time -from typing import ( - Dict, - List, - NamedTuple, - Optional, - Sequence, - Tuple, - Union, -) +import typing +from typing import Any, Callable, Dict, List, NamedTuple, Optional, Sequence, Tuple, Union from accelerate import Accelerator, InitProcessGroupKwargs, DistributedDataParallelKwargs, PartialState import glob import math @@ -30,10 +24,14 @@ from io import BytesIO import toml +# from concurrent.futures import ThreadPoolExecutor, as_completed + from tqdm import tqdm +from packaging.version import Version import torch from library.device_utils import init_ipex, clean_memory_on_device +from library.strategy_base import LatentsCachingStrategy, TokenizeStrategy, TextEncoderOutputsCachingStrategy, TextEncodingStrategy init_ipex() @@ -62,7 +60,7 @@ KDPM2AncestralDiscreteScheduler, AutoencoderKL, ) -from library import custom_train_functions +from library import custom_train_functions, sd3_utils from library.original_unet import UNet2DConditionModel from huggingface_hub import hf_hub_download import numpy as np @@ -71,11 +69,12 @@ import cv2 import safetensors.torch from library.lpw_stable_diffusion import StableDiffusionLongPromptWeightingPipeline +from library.sdxl_lpw_stable_diffusion import SdxlStableDiffusionLongPromptWeightingPipeline import library.model_util as model_util import library.huggingface_util as huggingface_util import library.sai_model_spec as sai_model_spec import library.deepspeed_utils as deepspeed_utils -from library.utils import setup_logging, pil_resize +from library.utils import setup_logging, resize_image, validate_interpolation_fn setup_logging() import logging @@ -85,10 +84,6 @@ # from library.hypernetwork import replace_attentions_for_hypernetwork from library.original_unet import UNet2DConditionModel -# Tokenizer: checkpointから読み込むのではなくあらかじめ提供されているものを使う -TOKENIZER_PATH = "openai/clip-vit-large-patch14" -V2_STABLE_DIFFUSION_PATH = "stabilityai/stable-diffusion-2" # ここからtokenizerだけ使う v2とv2.1はtokenizer仕様は同じ - HIGH_VRAM = False # checkpointファイル名 @@ -141,6 +136,46 @@ ) TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX = "_te_outputs.npz" +TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX_SD3 = "_sd3_te.npz" + + +def split_train_val( + paths: List[str], + sizes: List[Optional[Tuple[int, int]]], + is_training_dataset: bool, + validation_split: float, + validation_seed: int | None, +) -> Tuple[List[str], List[Optional[Tuple[int, int]]]]: + """ + Split the dataset into train and validation + + Shuffle the dataset based on the validation_seed or the current random seed. + For example if the split of 0.2 of 100 images. + [0:80] = 80 training images + [80:] = 20 validation images + """ + dataset = list(zip(paths, sizes)) + if validation_seed is not None: + logging.info(f"Using validation seed: {validation_seed}") + prevstate = random.getstate() + random.seed(validation_seed) + random.shuffle(dataset) + random.setstate(prevstate) + else: + random.shuffle(dataset) + + paths, sizes = zip(*dataset) + paths = list(paths) + sizes = list(sizes) + # Split the dataset between training and validation + if is_training_dataset: + # Training dataset we split to the first part + split = math.ceil(len(paths) * (1 - validation_split)) + return paths[0:split], sizes[0:split] + else: + # Validation dataset we split to the second part + split = len(paths) - round(len(paths) * validation_split) + return paths[split:], sizes[split:] class ImageInfo: @@ -153,19 +188,26 @@ def __init__(self, image_key: str, num_repeats: int, caption: str, is_reg: bool, self.image_size: Tuple[int, int] = None self.resized_size: Tuple[int, int] = None self.bucket_reso: Tuple[int, int] = None - self.latents: torch.Tensor = None - self.latents_flipped: torch.Tensor = None - self.latents_npz: str = None - self.latents_original_size: Tuple[int, int] = None # original image size, not latents size - self.latents_crop_ltrb: Tuple[int, int] = None # crop left top right bottom in original pixel size, not latents size - self.cond_img_path: str = None + self.latents: Optional[torch.Tensor] = None + self.latents_flipped: Optional[torch.Tensor] = None + self.latents_npz: Optional[str] = None # set in cache_latents + self.latents_original_size: Optional[Tuple[int, int]] = None # original image size, not latents size + self.latents_crop_ltrb: Optional[Tuple[int, int]] = ( + None # crop left top right bottom in original pixel size, not latents size + ) + self.cond_img_path: Optional[str] = None self.image: Optional[Image.Image] = None # optional, original PIL Image - # SDXL, optional - self.text_encoder_outputs_npz: Optional[str] = None + self.text_encoder_outputs_npz: Optional[str] = None # set in cache_text_encoder_outputs + + # new + self.text_encoder_outputs: Optional[List[torch.Tensor]] = None + # old self.text_encoder_outputs1: Optional[torch.Tensor] = None self.text_encoder_outputs2: Optional[torch.Tensor] = None self.text_encoder_pool2: Optional[torch.Tensor] = None + self.alpha_mask: Optional[torch.Tensor] = None # alpha mask can be flipped in runtime + self.resize_interpolation: Optional[str] = None class BucketManager: @@ -387,6 +429,10 @@ def __init__( caption_suffix: Optional[str], token_warmup_min: int, token_warmup_step: Union[float, int], + custom_attributes: Optional[Dict[str, Any]] = None, + validation_seed: Optional[int] = None, + validation_split: Optional[float] = 0.0, + resize_interpolation: Optional[str] = None, ) -> None: self.image_dir = image_dir self.alpha_mask = alpha_mask if alpha_mask is not None else False @@ -410,8 +456,15 @@ def __init__( self.token_warmup_min = token_warmup_min # step=0におけるタグの数 self.token_warmup_step = token_warmup_step # N(N<1ならN*max_train_steps)ステップ目でタグの数が最大になる + self.custom_attributes = custom_attributes if custom_attributes is not None else {} + self.img_count = 0 + self.validation_seed = validation_seed + self.validation_split = validation_split + + self.resize_interpolation = resize_interpolation + class DreamBoothSubset(BaseSubset): def __init__( @@ -440,6 +493,10 @@ def __init__( caption_suffix, token_warmup_min, token_warmup_step, + custom_attributes: Optional[Dict[str, Any]] = None, + validation_seed: Optional[int] = None, + validation_split: Optional[float] = 0.0, + resize_interpolation: Optional[str] = None, ) -> None: assert image_dir is not None, "image_dir must be specified / image_dirは指定が必須です" @@ -464,6 +521,10 @@ def __init__( caption_suffix, token_warmup_min, token_warmup_step, + custom_attributes=custom_attributes, + validation_seed=validation_seed, + validation_split=validation_split, + resize_interpolation=resize_interpolation, ) self.is_reg = is_reg @@ -503,6 +564,10 @@ def __init__( caption_suffix, token_warmup_min, token_warmup_step, + custom_attributes: Optional[Dict[str, Any]] = None, + validation_seed: Optional[int] = None, + validation_split: Optional[float] = 0.0, + resize_interpolation: Optional[str] = None, ) -> None: assert metadata_file is not None, "metadata_file must be specified / metadata_fileは指定が必須です" @@ -527,6 +592,10 @@ def __init__( caption_suffix, token_warmup_min, token_warmup_step, + custom_attributes=custom_attributes, + validation_seed=validation_seed, + validation_split=validation_split, + resize_interpolation=resize_interpolation, ) self.metadata_file = metadata_file @@ -562,6 +631,10 @@ def __init__( caption_suffix, token_warmup_min, token_warmup_step, + custom_attributes: Optional[Dict[str, Any]] = None, + validation_seed: Optional[int] = None, + validation_split: Optional[float] = 0.0, + resize_interpolation: Optional[str] = None, ) -> None: assert image_dir is not None, "image_dir must be specified / image_dirは指定が必須です" @@ -586,6 +659,10 @@ def __init__( caption_suffix, token_warmup_min, token_warmup_step, + custom_attributes=custom_attributes, + validation_seed=validation_seed, + validation_split=validation_split, + resize_interpolation=resize_interpolation, ) self.conditioning_data_dir = conditioning_data_dir @@ -603,17 +680,13 @@ def __eq__(self, other) -> bool: class BaseDataset(torch.utils.data.Dataset): def __init__( self, - tokenizer: Union[CLIPTokenizer, List[CLIPTokenizer]], - max_token_length: int, resolution: Optional[Tuple[int, int]], network_multiplier: float, debug_dataset: bool, + resize_interpolation: Optional[str] = None, ) -> None: super().__init__() - self.tokenizers = tokenizer if isinstance(tokenizer, list) else [tokenizer] - - self.max_token_length = max_token_length # width/height is used when enable_bucket==False self.width, self.height = (None, None) if resolution is None else resolution self.network_multiplier = network_multiplier @@ -634,8 +707,6 @@ def __init__( self.bucket_no_upscale = None self.bucket_info = None # for metadata - self.tokenizer_max_length = self.tokenizers[0].model_max_length if max_token_length is None else max_token_length + 2 - self.current_epoch: int = 0 # インスタンスがepochごとに新しく作られるようなので外側から渡さないとダメ self.current_step: int = 0 @@ -647,6 +718,12 @@ def __init__( self.image_transforms = IMAGE_TRANSFORMS + if resize_interpolation is not None: + assert validate_interpolation_fn( + resize_interpolation + ), f'Resize interpolation "{resize_interpolation}" is not a valid interpolation' + self.resize_interpolation = resize_interpolation + self.image_data: Dict[str, ImageInfo] = {} self.image_to_subset: Dict[str, Union[DreamBoothSubset, FineTuningSubset]] = {} @@ -655,6 +732,15 @@ def __init__( # caching self.caching_mode = None # None, 'latents', 'text' + self.tokenize_strategy = None + self.text_encoder_output_caching_strategy = None + self.latents_caching_strategy = None + + def set_current_strategies(self): + self.tokenize_strategy = TokenizeStrategy.get_strategy() + self.text_encoder_output_caching_strategy = TextEncoderOutputsCachingStrategy.get_strategy() + self.latents_caching_strategy = LatentsCachingStrategy.get_strategy() + def adjust_min_max_bucket_reso_by_steps( self, resolution: Tuple[int, int], min_bucket_reso: int, max_bucket_reso: int, bucket_reso_steps: int ) -> Tuple[int, int]: @@ -905,6 +991,23 @@ def make_buckets(self): if info.image_size is None: info.image_size = self.get_image_size(info.absolute_path) + # # run in parallel + # max_workers = min(os.cpu_count(), len(self.image_data)) # TODO consider multi-gpu (processes) + # with ThreadPoolExecutor(max_workers) as executor: + # futures = [] + # for info in tqdm(self.image_data.values(), desc="loading image sizes"): + # if info.image_size is None: + # def get_and_set_image_size(info): + # info.image_size = self.get_image_size(info.absolute_path) + # futures.append(executor.submit(get_and_set_image_size, info)) + # # consume futures to reduce memory usage and prevent Ctrl-C hang + # if len(futures) >= max_workers: + # for future in futures: + # future.result() + # futures = [] + # for future in futures: + # future.result() + if self.enable_bucket: logger.info("make buckets") else: @@ -974,22 +1077,6 @@ def make_buckets(self): for batch_index in range(batch_count): self.buckets_indices.append(BucketBatchIndex(bucket_index, self.batch_size, batch_index)) - # ↓以下はbucketごとのbatch件数があまりにも増えて混乱を招くので元に戻す - #  学習時はステップ数がランダムなので、同一画像が同一batch内にあってもそれほど悪影響はないであろう、と考えられる - # - # # bucketが細分化されることにより、ひとつのbucketに一種類の画像のみというケースが増え、つまりそれは - # # ひとつのbatchが同じ画像で占められることになるので、さすがに良くないであろう - # # そのためバッチサイズを画像種類までに制限する - # # ただそれでも同一画像が同一バッチに含まれる可能性はあるので、繰り返し回数が少ないほうがshuffleの品質は良くなることは間違いない? - # # TO DO 正則化画像をepochまたがりで利用する仕組み - # num_of_image_types = len(set(bucket)) - # bucket_batch_size = min(self.batch_size, num_of_image_types) - # batch_count = int(math.ceil(len(bucket) / bucket_batch_size)) - # # logger.info(bucket_index, num_of_image_types, bucket_batch_size, batch_count) - # for batch_index in range(batch_count): - # self.buckets_indices.append(BucketBatchIndex(bucket_index, bucket_batch_size, batch_index)) - # ↑ここまで - self.shuffle_buckets() self._length = len(self.buckets_indices) @@ -1022,7 +1109,111 @@ def is_text_encoder_output_cacheable(self): ] ) - def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True): + def new_cache_latents(self, model: Any, accelerator: Accelerator): + r""" + a brand new method to cache latents. This method caches latents with caching strategy. + normal cache_latents method is used by default, but this method is used when caching strategy is specified. + """ + logger.info("caching latents with caching strategy.") + caching_strategy = LatentsCachingStrategy.get_strategy() + image_infos = list(self.image_data.values()) + + # sort by resolution + image_infos.sort(key=lambda info: info.bucket_reso[0] * info.bucket_reso[1]) + + # split by resolution and some conditions + class Condition: + def __init__(self, reso, flip_aug, alpha_mask, random_crop): + self.reso = reso + self.flip_aug = flip_aug + self.alpha_mask = alpha_mask + self.random_crop = random_crop + + def __eq__(self, other): + return ( + self.reso == other.reso + and self.flip_aug == other.flip_aug + and self.alpha_mask == other.alpha_mask + and self.random_crop == other.random_crop + ) + + batch: List[ImageInfo] = [] + current_condition = None + + # support multiple-gpus + num_processes = accelerator.num_processes + process_index = accelerator.process_index + + # define a function to submit a batch to cache + def submit_batch(batch, cond): + for info in batch: + if info.image is not None and isinstance(info.image, Future): + info.image = info.image.result() # future to image + caching_strategy.cache_batch_latents(model, batch, cond.flip_aug, cond.alpha_mask, cond.random_crop) + + # remove image from memory + for info in batch: + info.image = None + + # define ThreadPoolExecutor to load images in parallel + max_workers = min(os.cpu_count(), len(image_infos)) + max_workers = max(1, max_workers // num_processes) # consider multi-gpu + max_workers = min(max_workers, caching_strategy.batch_size) # max_workers should be less than batch_size + executor = ThreadPoolExecutor(max_workers) + + try: + # iterate images + logger.info("caching latents...") + for i, info in enumerate(tqdm(image_infos)): + subset = self.image_to_subset[info.image_key] + + if info.latents_npz is not None: # fine tuning dataset + continue + + # check disk cache exists and size of latents + if caching_strategy.cache_to_disk: + # info.latents_npz = os.path.splitext(info.absolute_path)[0] + file_suffix + info.latents_npz = caching_strategy.get_latents_npz_path(info.absolute_path, info.image_size) + + # if the modulo of num_processes is not equal to process_index, skip caching + # this makes each process cache different latents + if i % num_processes != process_index: + continue + + # print(f"{process_index}/{num_processes} {i}/{len(image_infos)} {info.latents_npz}") + + cache_available = caching_strategy.is_disk_cached_latents_expected( + info.bucket_reso, info.latents_npz, subset.flip_aug, subset.alpha_mask + ) + if cache_available: # do not add to batch + continue + + # if batch is not empty and condition is changed, flush the batch. Note that current_condition is not None if batch is not empty + condition = Condition(info.bucket_reso, subset.flip_aug, subset.alpha_mask, subset.random_crop) + if len(batch) > 0 and current_condition != condition: + submit_batch(batch, current_condition) + batch = [] + + if info.image is None: + # load image in parallel + info.image = executor.submit(load_image, info.absolute_path, condition.alpha_mask) + + batch.append(info) + current_condition = condition + + # if number of data in batch is enough, flush the batch + if len(batch) >= caching_strategy.batch_size: + submit_batch(batch, current_condition) + batch = [] + current_condition = None + + if len(batch) > 0: + submit_batch(batch, current_condition) + + finally: + executor.shutdown() + + def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True, file_suffix=".npz"): # マルチGPUには対応していないので、そちらはtools/cache_latents.pyを使うこと logger.info("caching latents.") @@ -1060,7 +1251,7 @@ def __eq__(self, other): # check disk cache exists and size of latents if cache_to_disk: - info.latents_npz = os.path.splitext(info.absolute_path)[0] + ".npz" + info.latents_npz = os.path.splitext(info.absolute_path)[0] + file_suffix if not is_main_process: # store to info only continue @@ -1097,17 +1288,110 @@ def __eq__(self, other): for condition, batch in tqdm(batches, smoothing=1, total=len(batches)): cache_batch_latents(vae, cache_to_disk, batch, condition.flip_aug, condition.alpha_mask, condition.random_crop) - # weight_dtypeを指定するとText Encoderそのもの、およひ出力がweight_dtypeになる - # SDXLでのみ有効だが、datasetのメソッドとする必要があるので、sdxl_train_util.pyではなくこちらに実装する - # SD1/2に対応するにはv2のフラグを持つ必要があるので後回し + def new_cache_text_encoder_outputs(self, models: List[Any], accelerator: Accelerator): + r""" + a brand new method to cache text encoder outputs. This method caches text encoder outputs with caching strategy. + """ + tokenize_strategy = TokenizeStrategy.get_strategy() + text_encoding_strategy = TextEncodingStrategy.get_strategy() + caching_strategy = TextEncoderOutputsCachingStrategy.get_strategy() + batch_size = caching_strategy.batch_size or self.batch_size + + logger.info("caching Text Encoder outputs with caching strategy.") + image_infos = list(self.image_data.values()) + + # split by resolution + batches = [] + batch = [] + + # support multiple-gpus + num_processes = accelerator.num_processes + process_index = accelerator.process_index + + logger.info("checking cache validity...") + for i, info in enumerate(tqdm(image_infos)): + # check disk cache exists and size of text encoder outputs + if caching_strategy.cache_to_disk: + te_out_npz = caching_strategy.get_outputs_npz_path(info.absolute_path) + info.text_encoder_outputs_npz = te_out_npz # set npz filename regardless of cache availability + + # if the modulo of num_processes is not equal to process_index, skip caching + # this makes each process cache different text encoder outputs + if i % num_processes != process_index: + continue + + cache_available = caching_strategy.is_disk_cached_outputs_expected(te_out_npz) + if cache_available: # do not add to batch + continue + + batch.append(info) + + # if number of data in batch is enough, flush the batch + if len(batch) >= batch_size: + batches.append(batch) + batch = [] + + if len(batch) > 0: + batches.append(batch) + + if len(batches) == 0: + logger.info("no Text Encoder outputs to cache") + return + + # iterate batches + logger.info("caching Text Encoder outputs...") + for batch in tqdm(batches, smoothing=1, total=len(batches)): + # cache_batch_latents(vae, cache_to_disk, batch, subset.flip_aug, subset.alpha_mask, subset.random_crop) + caching_strategy.cache_batch_outputs(tokenize_strategy, models, text_encoding_strategy, batch) + + # if weight_dtype is specified, Text Encoder itself and output will be converted to the dtype + # this method is only for SDXL, but it should be implemented here because it needs to be a method of dataset + # to support SD1/2, it needs a flag for v2, but it is postponed def cache_text_encoder_outputs( - self, tokenizers, text_encoders, device, weight_dtype, cache_to_disk=False, is_main_process=True + self, tokenizers, text_encoders, device, output_dtype, cache_to_disk=False, is_main_process=True ): assert len(tokenizers) == 2, "only support SDXL" + return self.cache_text_encoder_outputs_common( + tokenizers, text_encoders, [device, device], output_dtype, [output_dtype], cache_to_disk, is_main_process + ) + # same as above, but for SD3 + def cache_text_encoder_outputs_sd3( + self, tokenizer, text_encoders, devices, output_dtype, te_dtypes, cache_to_disk=False, is_main_process=True, batch_size=None + ): + return self.cache_text_encoder_outputs_common( + [tokenizer], + text_encoders, + devices, + output_dtype, + te_dtypes, + cache_to_disk, + is_main_process, + TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX_SD3, + batch_size, + ) + + def cache_text_encoder_outputs_common( + self, + tokenizers, + text_encoders, + devices, + output_dtype, + te_dtypes, + cache_to_disk=False, + is_main_process=True, + file_suffix=TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX, + batch_size=None, + ): # latentsのキャッシュと同様に、ディスクへのキャッシュに対応する # またマルチGPUには対応していないので、そちらはtools/cache_latents.pyを使うこと logger.info("caching text encoder outputs.") + + tokenize_strategy = TokenizeStrategy.get_strategy() + + if batch_size is None: + batch_size = self.batch_size + image_infos = list(self.image_data.values()) logger.info("checking cache existence...") @@ -1115,13 +1399,14 @@ def cache_text_encoder_outputs( for info in tqdm(image_infos): # subset = self.image_to_subset[info.image_key] if cache_to_disk: - te_out_npz = os.path.splitext(info.absolute_path)[0] + TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX + te_out_npz = os.path.splitext(info.absolute_path)[0] + file_suffix info.text_encoder_outputs_npz = te_out_npz if not is_main_process: # store to info only continue if os.path.exists(te_out_npz): + # TODO check varidity of cache here continue image_infos_to_cache.append(info) @@ -1130,20 +1415,25 @@ def cache_text_encoder_outputs( return # prepare tokenizers and text encoders - for text_encoder in text_encoders: + for text_encoder, device, te_dtype in zip(text_encoders, devices, te_dtypes): text_encoder.to(device) - if weight_dtype is not None: - text_encoder.to(dtype=weight_dtype) + if te_dtype is not None: + text_encoder.to(dtype=te_dtype) # create batch + is_sd3 = len(tokenizers) == 1 batch = [] batches = [] for info in image_infos_to_cache: - input_ids1 = self.get_input_ids(info.caption, tokenizers[0]) - input_ids2 = self.get_input_ids(info.caption, tokenizers[1]) - batch.append((info, input_ids1, input_ids2)) + if not is_sd3: + input_ids1 = self.get_input_ids(info.caption, tokenizers[0]) + input_ids2 = self.get_input_ids(info.caption, tokenizers[1]) + batch.append((info, input_ids1, input_ids2)) + else: + l_tokens, g_tokens, t5_tokens = tokenize_strategy.tokenize(info.caption) + batch.append((info, l_tokens, g_tokens, t5_tokens)) - if len(batch) >= self.batch_size: + if len(batch) >= batch_size: batches.append(batch) batch = [] @@ -1152,18 +1442,47 @@ def cache_text_encoder_outputs( # iterate batches: call text encoder and cache outputs for memory or disk logger.info("caching text encoder outputs...") - for batch in tqdm(batches): - infos, input_ids1, input_ids2 = zip(*batch) - input_ids1 = torch.stack(input_ids1, dim=0) - input_ids2 = torch.stack(input_ids2, dim=0) - cache_batch_text_encoder_outputs( - infos, tokenizers, text_encoders, self.max_token_length, cache_to_disk, input_ids1, input_ids2, weight_dtype - ) + if not is_sd3: + for batch in tqdm(batches): + infos, input_ids1, input_ids2 = zip(*batch) + input_ids1 = torch.stack(input_ids1, dim=0) + input_ids2 = torch.stack(input_ids2, dim=0) + cache_batch_text_encoder_outputs( + infos, tokenizers, text_encoders, self.max_token_length, cache_to_disk, input_ids1, input_ids2, output_dtype + ) + else: + for batch in tqdm(batches): + infos, l_tokens, g_tokens, t5_tokens = zip(*batch) + + # stack tokens + # l_tokens = [tokens[0] for tokens in l_tokens] + # g_tokens = [tokens[0] for tokens in g_tokens] + # t5_tokens = [tokens[0] for tokens in t5_tokens] + + cache_batch_text_encoder_outputs_sd3( + infos, + tokenizers[0], + text_encoders, + self.max_token_length, + cache_to_disk, + (l_tokens, g_tokens, t5_tokens), + output_dtype, + ) def get_image_size(self, image_path): if image_path.endswith(".jxl") or image_path.endswith(".JXL"): return get_jxl_size(image_path) - return imagesize.get(image_path) + # return imagesize.get(image_path) + image_size = imagesize.get(image_path) + if image_size[0] <= 0: + # imagesize doesn't work for some images, so use PIL as a fallback + try: + with Image.open(image_path) as img: + image_size = img.size + except Exception as e: + logger.warning(f"failed to get image size: {image_path}, error: {e}") + image_size = (0, 0) + return image_size def load_image_with_face_info(self, subset: BaseSubset, image_path: str, alpha_mask=False): img = load_image(image_path, alpha_mask) @@ -1199,7 +1518,7 @@ def crop_target(self, subset: BaseSubset, image, face_cx, face_cy, face_w, face_ nh = int(height * scale + 0.5) nw = int(width * scale + 0.5) assert nh >= self.height and nw >= self.width, f"internal error. small scale {scale}, {width}*{height}" - image = cv2.resize(image, (nw, nh), interpolation=cv2.INTER_AREA) + image = resize_image(image, width, height, nw, nh, subset.resize_interpolation) face_cx = int(face_cx * scale + 0.5) face_cy = int(face_cy * scale + 0.5) height, width = nh, nw @@ -1241,7 +1560,6 @@ def __getitem__(self, index): loss_weights = [] captions = [] input_ids_list = [] - input_ids2_list = [] latents_list = [] alpha_mask_list = [] images = [] @@ -1249,16 +1567,17 @@ def __getitem__(self, index): crop_top_lefts = [] target_sizes_hw = [] flippeds = [] # 変数名が微妙 - text_encoder_outputs1_list = [] - text_encoder_outputs2_list = [] - text_encoder_pool2_list = [] + text_encoder_outputs_list = [] + custom_attributes = [] for image_key in bucket[image_index : image_index + bucket_batch_size]: image_info = self.image_data[image_key] subset = self.image_to_subset[image_key] - loss_weights.append( - self.prior_loss_weight if image_info.is_reg else 1.0 - ) # in case of fine tuning, is_reg is always False + + custom_attributes.append(subset.custom_attributes) + + # in case of fine tuning, is_reg is always False + loss_weights.append(self.prior_loss_weight if image_info.is_reg else 1.0) flipped = subset.flip_aug and random.random() < 0.5 # not flipped or flipped with 50% chance @@ -1275,7 +1594,9 @@ def __getitem__(self, index): image = None elif image_info.latents_npz is not None: # FineTuningDatasetまたはcache_latents_to_disk=Trueの場合 - latents, original_size, crop_ltrb, flipped_latents, alpha_mask = load_latents_from_disk(image_info.latents_npz) + latents, original_size, crop_ltrb, flipped_latents, alpha_mask = ( + self.latents_caching_strategy.load_latents_from_disk(image_info.latents_npz, image_info.bucket_reso) + ) if flipped: latents = flipped_latents alpha_mask = None if alpha_mask is None else alpha_mask[:, ::-1].copy() # copy to avoid negative stride problem @@ -1294,7 +1615,11 @@ def __getitem__(self, index): if self.enable_bucket: img, original_size, crop_ltrb = trim_and_resize_if_required( - subset.random_crop, img, image_info.bucket_reso, image_info.resized_size + subset.random_crop, + img, + image_info.bucket_reso, + image_info.resized_size, + resize_interpolation=image_info.resize_interpolation, ) else: if face_cx > 0: # 顔位置情報あり @@ -1364,74 +1689,101 @@ def __getitem__(self, index): # captionとtext encoder outputを処理する caption = image_info.caption # default - if image_info.text_encoder_outputs1 is not None: - text_encoder_outputs1_list.append(image_info.text_encoder_outputs1) - text_encoder_outputs2_list.append(image_info.text_encoder_outputs2) - text_encoder_pool2_list.append(image_info.text_encoder_pool2) - captions.append(caption) + + tokenization_required = ( + self.text_encoder_output_caching_strategy is None or self.text_encoder_output_caching_strategy.is_partial + ) + text_encoder_outputs = None + input_ids = None + + if image_info.text_encoder_outputs is not None: + # cached + text_encoder_outputs = image_info.text_encoder_outputs elif image_info.text_encoder_outputs_npz is not None: - text_encoder_outputs1, text_encoder_outputs2, text_encoder_pool2 = load_text_encoder_outputs_from_disk( + # on disk + text_encoder_outputs = self.text_encoder_output_caching_strategy.load_outputs_npz( image_info.text_encoder_outputs_npz ) - text_encoder_outputs1_list.append(text_encoder_outputs1) - text_encoder_outputs2_list.append(text_encoder_outputs2) - text_encoder_pool2_list.append(text_encoder_pool2) - captions.append(caption) else: + tokenization_required = True + text_encoder_outputs_list.append(text_encoder_outputs) + + if tokenization_required: caption = self.process_caption(subset, image_info.caption) - if self.XTI_layers: - caption_layer = [] - for layer in self.XTI_layers: - token_strings_from = " ".join(self.token_strings) - token_strings_to = " ".join([f"{x}_{layer}" for x in self.token_strings]) - caption_ = caption.replace(token_strings_from, token_strings_to) - caption_layer.append(caption_) - captions.append(caption_layer) - else: - captions.append(caption) + input_ids = [ids[0] for ids in self.tokenize_strategy.tokenize(caption)] # remove batch dimension + # if self.XTI_layers: + # caption_layer = [] + # for layer in self.XTI_layers: + # token_strings_from = " ".join(self.token_strings) + # token_strings_to = " ".join([f"{x}_{layer}" for x in self.token_strings]) + # caption_ = caption.replace(token_strings_from, token_strings_to) + # caption_layer.append(caption_) + # captions.append(caption_layer) + # else: + # captions.append(caption) + + # if not self.token_padding_disabled: # this option might be omitted in future + # # TODO get_input_ids must support SD3 + # if self.XTI_layers: + # token_caption = self.get_input_ids(caption_layer, self.tokenizers[0]) + # else: + # token_caption = self.get_input_ids(caption, self.tokenizers[0]) + # input_ids_list.append(token_caption) + + # if len(self.tokenizers) > 1: + # if self.XTI_layers: + # token_caption2 = self.get_input_ids(caption_layer, self.tokenizers[1]) + # else: + # token_caption2 = self.get_input_ids(caption, self.tokenizers[1]) + # input_ids2_list.append(token_caption2) + + input_ids_list.append(input_ids) + captions.append(caption) - if not self.token_padding_disabled: # this option might be omitted in future - if self.XTI_layers: - token_caption = self.get_input_ids(caption_layer, self.tokenizers[0]) - else: - token_caption = self.get_input_ids(caption, self.tokenizers[0]) - input_ids_list.append(token_caption) + def none_or_stack_elements(tensors_list, converter): + # [[clip_l, clip_g, t5xxl], [clip_l, clip_g, t5xxl], ...] -> [torch.stack(clip_l), torch.stack(clip_g), torch.stack(t5xxl)] + if len(tensors_list) == 0 or tensors_list[0] == None or len(tensors_list[0]) == 0 or tensors_list[0][0] is None: + return None + + # old implementation without padding: all elements must have same length + # return [torch.stack([converter(x[i]) for x in tensors_list]) for i in range(len(tensors_list[0]))] + + # new implementation with padding support + result = [] + for i in range(len(tensors_list[0])): + tensors = [x[i] for x in tensors_list] + if tensors[0].ndim == 0: + # scalar value: e.g. ocr mask + result.append(torch.stack([converter(x[i]) for x in tensors_list])) + continue - if len(self.tokenizers) > 1: - if self.XTI_layers: - token_caption2 = self.get_input_ids(caption_layer, self.tokenizers[1]) - else: - token_caption2 = self.get_input_ids(caption, self.tokenizers[1]) - input_ids2_list.append(token_caption2) + min_len = min([len(x) for x in tensors]) + max_len = max([len(x) for x in tensors]) + + if min_len == max_len: + # no padding + result.append(torch.stack([converter(x) for x in tensors])) + else: + # padding + tensors = [converter(x) for x in tensors] + if tensors[0].ndim == 1: + # input_ids or mask + result.append( + torch.stack([(torch.nn.functional.pad(x, (0, max_len - x.shape[0]))) for x in tensors]) + ) + else: + # text encoder outputs + result.append( + torch.stack([(torch.nn.functional.pad(x, (0, 0, 0, max_len - x.shape[0]))) for x in tensors]) + ) + return result + # set example example = {} + example["custom_attributes"] = custom_attributes # may be list of empty dict example["loss_weights"] = torch.FloatTensor(loss_weights) - - if len(text_encoder_outputs1_list) == 0: - if self.token_padding_disabled: - # padding=True means pad in the batch - example["input_ids"] = self.tokenizer[0](captions, padding=True, truncation=True, return_tensors="pt").input_ids - if len(self.tokenizers) > 1: - example["input_ids2"] = self.tokenizer[1]( - captions, padding=True, truncation=True, return_tensors="pt" - ).input_ids - else: - example["input_ids2"] = None - else: - example["input_ids"] = torch.stack(input_ids_list) - example["input_ids2"] = torch.stack(input_ids2_list) if len(self.tokenizers) > 1 else None - example["text_encoder_outputs1_list"] = None - example["text_encoder_outputs2_list"] = None - example["text_encoder_pool2_list"] = None - else: - example["input_ids"] = None - example["input_ids2"] = None - # # for assertion - # example["input_ids"] = torch.stack([self.get_input_ids(cap, self.tokenizers[0]) for cap in captions]) - # example["input_ids2"] = torch.stack([self.get_input_ids(cap, self.tokenizers[1]) for cap in captions]) - example["text_encoder_outputs1_list"] = torch.stack(text_encoder_outputs1_list) - example["text_encoder_outputs2_list"] = torch.stack(text_encoder_outputs2_list) - example["text_encoder_pool2_list"] = torch.stack(text_encoder_pool2_list) + example["text_encoder_outputs_list"] = none_or_stack_elements(text_encoder_outputs_list, torch.FloatTensor) + example["input_ids_list"] = none_or_stack_elements(input_ids_list, lambda x: x) # if one of alpha_masks is not None, we need to replace None with ones none_or_not = [x is None for x in alpha_mask_list] @@ -1541,12 +1893,14 @@ def get_item_for_caching(self, bucket, bucket_batch_size, image_index): class DreamBoothDataset(BaseDataset): IMAGE_INFO_CACHE_FILE = "metadata_cache.json" + # The is_training_dataset defines the type of dataset, training or validation + # if is_training_dataset is True -> training dataset + # if is_training_dataset is False -> validation dataset def __init__( self, subsets: Sequence[DreamBoothSubset], + is_training_dataset: bool, batch_size: int, - tokenizer, - max_token_length, resolution, network_multiplier: float, enable_bucket: bool, @@ -1556,8 +1910,11 @@ def __init__( bucket_no_upscale: bool, prior_loss_weight: float, debug_dataset: bool, + validation_split: float, + validation_seed: Optional[int], + resize_interpolation: Optional[str], ) -> None: - super().__init__(tokenizer, max_token_length, resolution, network_multiplier, debug_dataset) + super().__init__(resolution, network_multiplier, debug_dataset, resize_interpolation) assert resolution is not None, f"resolution is required / resolution(解像度)指定は必須です" @@ -1565,6 +1922,9 @@ def __init__( self.size = min(self.width, self.height) # 短いほう self.prior_loss_weight = prior_loss_weight self.latents_cache = None + self.is_training_dataset = is_training_dataset + self.validation_seed = validation_seed + self.validation_split = validation_split self.enable_bucket = enable_bucket if self.enable_bucket: @@ -1630,12 +1990,69 @@ def load_dreambooth_dir(subset: DreamBoothSubset): with open(info_cache_file, "r", encoding="utf-8") as f: metas = json.load(f) img_paths = list(metas.keys()) - sizes = [meta["resolution"] for meta in metas.values()] + sizes: List[Optional[Tuple[int, int]]] = [meta["resolution"] for meta in metas.values()] # we may need to check image size and existence of image files, but it takes time, so user should check it before training else: img_paths = glob_images(subset.image_dir, "*") - sizes = [None] * len(img_paths) + sizes: List[Optional[Tuple[int, int]]] = [None] * len(img_paths) + + # new caching: get image size from cache files + strategy = LatentsCachingStrategy.get_strategy() + if strategy is not None: + logger.info("get image size from name of cache files") + + # make image path to npz path mapping + npz_paths = glob.glob(os.path.join(subset.image_dir, "*" + strategy.cache_suffix)) + npz_paths.sort( + key=lambda item: item.rsplit("_", maxsplit=2)[0] + ) # sort by name excluding resolution and cache_suffix + npz_path_index = 0 + + size_set_count = 0 + for i, img_path in enumerate(tqdm(img_paths)): + l = len(os.path.splitext(img_path)[0]) # remove extension + found = False + while npz_path_index < len(npz_paths): # until found or end of npz_paths + # npz_paths are sorted, so if npz_path > img_path, img_path is not found + if npz_paths[npz_path_index][:l] > img_path[:l]: + break + if npz_paths[npz_path_index][:l] == img_path[:l]: # found + found = True + break + npz_path_index += 1 # next npz_path + + if found: + w, h = strategy.get_image_size_from_disk_cache_path(img_path, npz_paths[npz_path_index]) + else: + w, h = None, None + + if w is not None and h is not None: + sizes[i] = (w, h) + size_set_count += 1 + logger.info(f"set image size from cache files: {size_set_count}/{len(img_paths)}") + + # We want to create a training and validation split. This should be improved in the future + # to allow a clearer distinction between training and validation. This can be seen as a + # short-term solution to limit what is necessary to implement validation datasets + # + # We split the dataset for the subset based on if we are doing a validation split + # The self.is_training_dataset defines the type of dataset, training or validation + # if self.is_training_dataset is True -> training dataset + # if self.is_training_dataset is False -> validation dataset + if self.validation_split > 0.0: + # For regularization images we do not want to split this dataset. + if subset.is_reg is True: + # Skip any validation dataset for regularization images + if self.is_training_dataset is False: + img_paths = [] + sizes = [] + # Otherwise the img_paths remain as original img_paths and no split + # required for training images dataset of regularization images + else: + img_paths, sizes = split_train_val( + img_paths, sizes, self.is_training_dataset, self.validation_split, self.validation_seed + ) logger.info(f"found directory {subset.image_dir} contains {len(img_paths)} image files") @@ -1646,7 +2063,7 @@ def load_dreambooth_dir(subset: DreamBoothSubset): # 画像ファイルごとにプロンプトを読み込み、もしあればそちらを使う captions = [] missing_captions = [] - for img_path in img_paths: + for img_path in tqdm(img_paths, desc="read caption"): cap_for_img = read_caption(img_path, subset.caption_extension, subset.enable_wildcard) if cap_for_img is None and subset.class_tokens is None: logger.warning( @@ -1695,9 +2112,10 @@ def load_dreambooth_dir(subset: DreamBoothSubset): num_reg_images = 0 reg_infos: List[Tuple[ImageInfo, DreamBoothSubset]] = [] for subset in subsets: - if subset.num_repeats < 1: + num_repeats = subset.num_repeats if self.is_training_dataset else 1 + if num_repeats < 1: logger.warning( - f"ignore subset with image_dir='{subset.image_dir}': num_repeats is less than 1 / num_repeatsが1を下回っているためサブセットを無視します: {subset.num_repeats}" + f"ignore subset with image_dir='{subset.image_dir}': num_repeats is less than 1 / num_repeatsが1を下回っているためサブセットを無視します: {num_repeats}" ) continue @@ -1715,12 +2133,15 @@ def load_dreambooth_dir(subset: DreamBoothSubset): continue if subset.is_reg: - num_reg_images += subset.num_repeats * len(img_paths) + num_reg_images += num_repeats * len(img_paths) else: - num_train_images += subset.num_repeats * len(img_paths) + num_train_images += num_repeats * len(img_paths) for img_path, caption, size in zip(img_paths, captions, sizes): - info = ImageInfo(img_path, subset.num_repeats, caption, subset.is_reg, img_path) + info = ImageInfo(img_path, num_repeats, caption, subset.is_reg, img_path) + info.resize_interpolation = ( + subset.resize_interpolation if subset.resize_interpolation is not None else self.resize_interpolation + ) if size is not None: info.image_size = size if subset.is_reg: @@ -1731,10 +2152,12 @@ def load_dreambooth_dir(subset: DreamBoothSubset): subset.img_count = len(img_paths) self.subsets.append(subset) - logger.info(f"{num_train_images} train images with repeating.") + images_split_name = "train" if self.is_training_dataset else "validation" + logger.info(f"{num_train_images} {images_split_name} images with repeats.") + self.num_train_images = num_train_images - logger.info(f"{num_reg_images} reg images.") + logger.info(f"{num_reg_images} reg images with repeats.") if num_train_images < num_reg_images: logger.warning("some of reg images are not used / 正則化画像の数が多いので、一部使用されない正則化画像があります") @@ -1764,8 +2187,6 @@ def __init__( self, subsets: Sequence[FineTuningSubset], batch_size: int, - tokenizer, - max_token_length, resolution, network_multiplier: float, enable_bucket: bool, @@ -1774,8 +2195,11 @@ def __init__( bucket_reso_steps: int, bucket_no_upscale: bool, debug_dataset: bool, + validation_seed: int, + validation_split: float, + resize_interpolation: Optional[str], ) -> None: - super().__init__(tokenizer, max_token_length, resolution, network_multiplier, debug_dataset) + super().__init__(resolution, network_multiplier, debug_dataset, resize_interpolation) self.batch_size = batch_size @@ -1992,8 +2416,6 @@ def __init__( self, subsets: Sequence[ControlNetSubset], batch_size: int, - tokenizer, - max_token_length, resolution, network_multiplier: float, enable_bucket: bool, @@ -2001,9 +2423,12 @@ def __init__( max_bucket_reso: int, bucket_reso_steps: int, bucket_no_upscale: bool, - debug_dataset: float, + debug_dataset: bool, + validation_split: float, + validation_seed: Optional[int], + resize_interpolation: Optional[str] = None, ) -> None: - super().__init__(tokenizer, max_token_length, resolution, network_multiplier, debug_dataset) + super().__init__(resolution, network_multiplier, debug_dataset, resize_interpolation) db_subsets = [] for subset in subsets: @@ -2035,14 +2460,14 @@ def __init__( subset.caption_suffix, subset.token_warmup_min, subset.token_warmup_step, + resize_interpolation=subset.resize_interpolation, ) db_subsets.append(db_subset) self.dreambooth_dataset_delegate = DreamBoothDataset( db_subsets, + True, batch_size, - tokenizer, - max_token_length, resolution, network_multiplier, enable_bucket, @@ -2052,6 +2477,9 @@ def __init__( bucket_no_upscale, 1.0, debug_dataset, + validation_split, + validation_seed, + resize_interpolation, ) # config_util等から参照される値をいれておく(若干微妙なのでなんとかしたい) @@ -2059,6 +2487,9 @@ def __init__( self.batch_size = batch_size self.num_train_images = self.dreambooth_dataset_delegate.num_train_images self.num_reg_images = self.dreambooth_dataset_delegate.num_reg_images + self.validation_split = validation_split + self.validation_seed = validation_seed + self.resize_interpolation = resize_interpolation # assert all conditioning data exists missing_imgs = [] @@ -2102,6 +2533,9 @@ def __init__( self.conditioning_image_transforms = IMAGE_TRANSFORMS + def set_current_strategies(self): + return self.dreambooth_dataset_delegate.set_current_strategies() + def make_buckets(self): self.dreambooth_dataset_delegate.make_buckets() self.bucket_manager = self.dreambooth_dataset_delegate.bucket_manager @@ -2110,6 +2544,12 @@ def make_buckets(self): def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True): return self.dreambooth_dataset_delegate.cache_latents(vae, vae_batch_size, cache_to_disk, is_main_process) + def new_cache_latents(self, model: Any, accelerator: Accelerator): + return self.dreambooth_dataset_delegate.new_cache_latents(model, accelerator) + + def new_cache_text_encoder_outputs(self, models: List[Any], is_main_process: bool): + return self.dreambooth_dataset_delegate.new_cache_text_encoder_outputs(models, is_main_process) + def __len__(self): return self.dreambooth_dataset_delegate.__len__() @@ -2137,9 +2577,15 @@ def __getitem__(self, index): assert ( cond_img.shape[0] == original_size_hw[0] and cond_img.shape[1] == original_size_hw[1] ), f"size of conditioning image is not match / 画像サイズが合いません: {image_info.absolute_path}" - cond_img = cv2.resize( - cond_img, image_info.resized_size, interpolation=cv2.INTER_AREA - ) # INTER_AREAでやりたいのでcv2でリサイズ + + cond_img = resize_image( + cond_img, + original_size_hw[1], + original_size_hw[0], + target_size_hw[1], + target_size_hw[0], + self.resize_interpolation, + ) # TODO support random crop # 現在サポートしているcropはrandomではなく中央のみ @@ -2153,7 +2599,14 @@ def __getitem__(self, index): # ), f"image size is small / 画像サイズが小さいようです: {image_info.absolute_path}" # resize to target if cond_img.shape[0] != target_size_hw[0] or cond_img.shape[1] != target_size_hw[1]: - cond_img = pil_resize(cond_img, (int(target_size_hw[1]), int(target_size_hw[0]))) + cond_img = resize_image( + cond_img, + cond_img.shape[0], + cond_img.shape[1], + target_size_hw[1], + target_size_hw[0], + self.resize_interpolation, + ) if flipped: cond_img = cond_img[:, ::-1, :].copy() # copy to avoid negative stride @@ -2193,14 +2646,27 @@ def add_replacement(self, str_from, str_to): # for dataset in self.datasets: # dataset.make_buckets() + def set_text_encoder_output_caching_strategy(self, strategy: TextEncoderOutputsCachingStrategy): + """ + DataLoader is run in multiple processes, so we need to set the strategy manually. + """ + for dataset in self.datasets: + dataset.set_text_encoder_output_caching_strategy(strategy) + def enable_XTI(self, *args, **kwargs): for dataset in self.datasets: dataset.enable_XTI(*args, **kwargs) - def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True): + def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True, file_suffix=".npz"): for i, dataset in enumerate(self.datasets): logger.info(f"[Dataset {i}]") - dataset.cache_latents(vae, vae_batch_size, cache_to_disk, is_main_process) + dataset.cache_latents(vae, vae_batch_size, cache_to_disk, is_main_process, file_suffix) + + def new_cache_latents(self, model: Any, accelerator: Accelerator): + for i, dataset in enumerate(self.datasets): + logger.info(f"[Dataset {i}]") + dataset.new_cache_latents(model, accelerator) + accelerator.wait_for_everyone() def cache_text_encoder_outputs( self, tokenizers, text_encoders, device, weight_dtype, cache_to_disk=False, is_main_process=True @@ -2209,6 +2675,21 @@ def cache_text_encoder_outputs( logger.info(f"[Dataset {i}]") dataset.cache_text_encoder_outputs(tokenizers, text_encoders, device, weight_dtype, cache_to_disk, is_main_process) + def cache_text_encoder_outputs_sd3( + self, tokenizer, text_encoders, device, output_dtype, te_dtypes, cache_to_disk=False, is_main_process=True, batch_size=None + ): + for i, dataset in enumerate(self.datasets): + logger.info(f"[Dataset {i}]") + dataset.cache_text_encoder_outputs_sd3( + tokenizer, text_encoders, device, output_dtype, te_dtypes, cache_to_disk, is_main_process, batch_size + ) + + def new_cache_text_encoder_outputs(self, models: List[Any], accelerator: Accelerator): + for i, dataset in enumerate(self.datasets): + logger.info(f"[Dataset {i}]") + dataset.new_cache_text_encoder_outputs(models, accelerator) + accelerator.wait_for_everyone() + def set_caching_mode(self, caching_mode): for dataset in self.datasets: dataset.set_caching_mode(caching_mode) @@ -2217,12 +2698,19 @@ def verify_bucket_reso_steps(self, min_steps: int): for dataset in self.datasets: dataset.verify_bucket_reso_steps(min_steps) + def get_resolutions(self) -> List[Tuple[int, int]]: + return [(dataset.width, dataset.height) for dataset in self.datasets] + def is_latent_cacheable(self) -> bool: return all([dataset.is_latent_cacheable() for dataset in self.datasets]) def is_text_encoder_output_cacheable(self) -> bool: return all([dataset.is_text_encoder_output_cacheable() for dataset in self.datasets]) + def set_current_strategies(self): + for dataset in self.datasets: + dataset.set_current_strategies() + def set_current_epoch(self, epoch): for dataset in self.datasets: dataset.set_current_epoch(epoch) @@ -2275,34 +2763,35 @@ def is_disk_cached_latents_is_expected(reso, npz_path: str, flip_aug: bool, alph # 戻り値は、latents_tensor, (original_size width, original_size height), (crop left, crop top) -def load_latents_from_disk( - npz_path, -) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]: - npz = np.load(npz_path) - if "latents" not in npz: - raise ValueError(f"error: npz is old format. please re-generate {npz_path}") - - latents = npz["latents"] - original_size = npz["original_size"].tolist() - crop_ltrb = npz["crop_ltrb"].tolist() - flipped_latents = npz["latents_flipped"] if "latents_flipped" in npz else None - alpha_mask = npz["alpha_mask"] if "alpha_mask" in npz else None - return latents, original_size, crop_ltrb, flipped_latents, alpha_mask - - -def save_latents_to_disk(npz_path, latents_tensor, original_size, crop_ltrb, flipped_latents_tensor=None, alpha_mask=None): - kwargs = {} - if flipped_latents_tensor is not None: - kwargs["latents_flipped"] = flipped_latents_tensor.float().cpu().numpy() - if alpha_mask is not None: - kwargs["alpha_mask"] = alpha_mask.float().cpu().numpy() - np.savez( - npz_path, - latents=latents_tensor.float().cpu().numpy(), - original_size=np.array(original_size), - crop_ltrb=np.array(crop_ltrb), - **kwargs, - ) +# TODO update to use CachingStrategy +# def load_latents_from_disk( +# npz_path, +# ) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]: +# npz = np.load(npz_path) +# if "latents" not in npz: +# raise ValueError(f"error: npz is old format. please re-generate {npz_path}") + +# latents = npz["latents"] +# original_size = npz["original_size"].tolist() +# crop_ltrb = npz["crop_ltrb"].tolist() +# flipped_latents = npz["latents_flipped"] if "latents_flipped" in npz else None +# alpha_mask = npz["alpha_mask"] if "alpha_mask" in npz else None +# return latents, original_size, crop_ltrb, flipped_latents, alpha_mask + + +# def save_latents_to_disk(npz_path, latents_tensor, original_size, crop_ltrb, flipped_latents_tensor=None, alpha_mask=None): +# kwargs = {} +# if flipped_latents_tensor is not None: +# kwargs["latents_flipped"] = flipped_latents_tensor.float().cpu().numpy() +# if alpha_mask is not None: +# kwargs["alpha_mask"] = alpha_mask.float().cpu().numpy() +# np.savez( +# npz_path, +# latents=latents_tensor.float().cpu().numpy(), +# original_size=np.array(original_size), +# crop_ltrb=np.array(crop_ltrb), +# **kwargs, +# ) def debug_dataset(train_dataset, show_input_ids=False): @@ -2329,12 +2818,12 @@ def debug_dataset(train_dataset, show_input_ids=False): example = train_dataset[idx] if example["latents"] is not None: logger.info(f"sample has latents from npz file: {example['latents'].size()}") - for j, (ik, cap, lw, iid, orgsz, crptl, trgsz, flpdz) in enumerate( + for j, (ik, cap, lw, orgsz, crptl, trgsz, flpdz) in enumerate( zip( example["image_keys"], example["captions"], example["loss_weights"], - example["input_ids"], + # example["input_ids"], example["original_sizes_hw"], example["crop_top_lefts"], example["target_sizes_hw"], @@ -2345,12 +2834,14 @@ def debug_dataset(train_dataset, show_input_ids=False): f'{ik}, size: {train_dataset.image_data[ik].image_size}, loss weight: {lw}, caption: "{cap}", original size: {orgsz}, crop top left: {crptl}, target size: {trgsz}, flipped: {flpdz}' ) if "network_multipliers" in example: - print(f"network multiplier: {example['network_multipliers'][j]}") - - if show_input_ids: - logger.info(f"input ids: {iid}") - if "input_ids2" in example: - logger.info(f"input ids2: {example['input_ids2'][j]}") + logger.info(f"network multiplier: {example['network_multipliers'][j]}") + if "custom_attributes" in example: + logger.info(f"custom attributes: {example['custom_attributes'][j]}") + + # if show_input_ids: + # logger.info(f"input ids: {iid}") + # if "input_ids2" in example: + # logger.info(f"input ids2: {example['input_ids2'][j]}") if example["images"] is not None: im = example["images"][j] logger.info(f"image size: {im.size()}") @@ -2419,8 +2910,8 @@ def glob_images_pathlib(dir_path, recursive): class MinimalDataset(BaseDataset): - def __init__(self, tokenizer, max_token_length, resolution, network_multiplier, debug_dataset=False): - super().__init__(tokenizer, max_token_length, resolution, network_multiplier, debug_dataset) + def __init__(self, resolution, network_multiplier, debug_dataset=False): + super().__init__(resolution, network_multiplier, debug_dataset) self.num_train_images = 0 # update in subclass self.num_reg_images = 0 # update in subclass @@ -2481,8 +2972,11 @@ def __getitem__(self, idx): """ raise NotImplementedError + def get_resolutions(self) -> List[Tuple[int, int]]: + return [] + -def load_arbitrary_dataset(args, tokenizer) -> MinimalDataset: +def load_arbitrary_dataset(args, tokenizer=None) -> MinimalDataset: module = ".".join(args.dataset_class.split(".")[:-1]) dataset_class = args.dataset_class.split(".")[-1] module = importlib.import_module(module) @@ -2509,17 +3003,13 @@ def load_image(image_path, alpha=False): # 画像を読み込む。戻り値はnumpy.ndarray,(original width, original height),(crop left, crop top, crop right, crop bottom) def trim_and_resize_if_required( - random_crop: bool, image: np.ndarray, reso, resized_size: Tuple[int, int] + random_crop: bool, image: np.ndarray, reso, resized_size: Tuple[int, int], resize_interpolation: Optional[str] = None ) -> Tuple[np.ndarray, Tuple[int, int], Tuple[int, int, int, int]]: image_height, image_width = image.shape[0:2] original_size = (image_width, image_height) # size before resize if image_width != resized_size[0] or image_height != resized_size[1]: - # リサイズする - if image_width > resized_size[0] and image_height > resized_size[1]: - image = cv2.resize(image, resized_size, interpolation=cv2.INTER_AREA) # INTER_AREAでやりたいのでcv2でリサイズ - else: - image = pil_resize(image, resized_size) + image = resize_image(image, image_width, image_height, resized_size[0], resized_size[1], resize_interpolation) image_height, image_width = image.shape[0:2] @@ -2543,6 +3033,53 @@ def trim_and_resize_if_required( return image, original_size, crop_ltrb +# for new_cache_latents +def load_images_and_masks_for_caching( + image_infos: List[ImageInfo], use_alpha_mask: bool, random_crop: bool +) -> Tuple[torch.Tensor, List[np.ndarray], List[Tuple[int, int]], List[Tuple[int, int, int, int]]]: + r""" + requires image_infos to have: [absolute_path or image], bucket_reso, resized_size + + returns: image_tensor, alpha_masks, original_sizes, crop_ltrbs + + image_tensor: torch.Tensor = torch.Size([B, 3, H, W]), ...], normalized to [-1, 1] + alpha_masks: List[np.ndarray] = [np.ndarray([H, W]), ...], normalized to [0, 1] + original_sizes: List[Tuple[int, int]] = [(W, H), ...] + crop_ltrbs: List[Tuple[int, int, int, int]] = [(L, T, R, B), ...] + """ + images: List[torch.Tensor] = [] + alpha_masks: List[np.ndarray] = [] + original_sizes: List[Tuple[int, int]] = [] + crop_ltrbs: List[Tuple[int, int, int, int]] = [] + for info in image_infos: + image = load_image(info.absolute_path, use_alpha_mask) if info.image is None else np.array(info.image, np.uint8) + # TODO 画像のメタデータが壊れていて、メタデータから割り当てたbucketと実際の画像サイズが一致しない場合があるのでチェック追加要 + image, original_size, crop_ltrb = trim_and_resize_if_required( + random_crop, image, info.bucket_reso, info.resized_size, resize_interpolation=info.resize_interpolation + ) + + original_sizes.append(original_size) + crop_ltrbs.append(crop_ltrb) + + if use_alpha_mask: + if image.shape[2] == 4: + alpha_mask = image[:, :, 3] # [H,W] + alpha_mask = alpha_mask.astype(np.float32) / 255.0 + alpha_mask = torch.FloatTensor(alpha_mask) # [H,W] + else: + alpha_mask = torch.ones_like(image[:, :, 0], dtype=torch.float32) # [H,W] + else: + alpha_mask = None + alpha_masks.append(alpha_mask) + + image = image[:, :, :3] # remove alpha channel if exists + image = IMAGE_TRANSFORMS(image) + images.append(image) + + img_tensor = torch.stack(images, dim=0) + return img_tensor, alpha_masks, original_sizes, crop_ltrbs + + def cache_batch_latents( vae: AutoencoderKL, cache_to_disk: bool, image_infos: List[ImageInfo], flip_aug: bool, use_alpha_mask: bool, random_crop: bool ) -> None: @@ -2560,7 +3097,9 @@ def cache_batch_latents( for info in image_infos: image = load_image(info.absolute_path, use_alpha_mask) if info.image is None else np.array(info.image, np.uint8) # TODO 画像のメタデータが壊れていて、メタデータから割り当てたbucketと実際の画像サイズが一致しない場合があるのでチェック追加要 - image, original_size, crop_ltrb = trim_and_resize_if_required(random_crop, image, info.bucket_reso, info.resized_size) + image, original_size, crop_ltrb = trim_and_resize_if_required( + random_crop, image, info.bucket_reso, info.resized_size, resize_interpolation=info.resize_interpolation + ) info.latents_original_size = original_size info.latents_crop_ltrb = crop_ltrb @@ -2599,14 +3138,15 @@ def cache_batch_latents( raise RuntimeError(f"NaN detected in latents: {info.absolute_path}") if cache_to_disk: - save_latents_to_disk( - info.latents_npz, - latent, - info.latents_original_size, - info.latents_crop_ltrb, - flipped_latent, - alpha_mask, - ) + # save_latents_to_disk( + # info.latents_npz, + # latent, + # info.latents_original_size, + # info.latents_crop_ltrb, + # flipped_latent, + # alpha_mask, + # ) + pass else: info.latents = latent if flip_aug: @@ -2649,6 +3189,34 @@ def cache_batch_text_encoder_outputs( info.text_encoder_pool2 = pool2 +def cache_batch_text_encoder_outputs_sd3( + image_infos, tokenizer, text_encoders, max_token_length, cache_to_disk, input_ids, output_dtype +): + # make input_ids for each text encoder + l_tokens, g_tokens, t5_tokens = input_ids + + clip_l, clip_g, t5xxl = text_encoders + with torch.no_grad(): + b_lg_out, b_t5_out, b_pool = sd3_utils.get_cond_from_tokens( + l_tokens, g_tokens, t5_tokens, clip_l, clip_g, t5xxl, "cpu", output_dtype + ) + b_lg_out = b_lg_out.detach() + b_t5_out = b_t5_out.detach() + b_pool = b_pool.detach() + + for info, lg_out, t5_out, pool in zip(image_infos, b_lg_out, b_t5_out, b_pool): + # debug: NaN check + if torch.isnan(lg_out).any() or torch.isnan(t5_out).any() or torch.isnan(pool).any(): + raise RuntimeError(f"NaN detected in text encoder outputs: {info.absolute_path}") + + if cache_to_disk: + save_text_encoder_outputs_to_disk(info.text_encoder_outputs_npz, lg_out, t5_out, pool) + else: + info.text_encoder_outputs1 = lg_out + info.text_encoder_outputs2 = t5_out + info.text_encoder_pool2 = pool + + def save_text_encoder_outputs_to_disk(npz_path, hidden_state1, hidden_state2, pool2): np.savez( npz_path, @@ -2942,7 +3510,7 @@ def load_metadata_from_safetensors(safetensors_file: str) -> dict: def build_minimum_network_metadata( - v2: Optional[bool], + v2: Optional[str], base_model: Optional[str], network_module: str, network_dim: str, @@ -2971,6 +3539,10 @@ def get_sai_model_spec( lora: bool, textual_inversion: bool, is_stable_diffusion_ckpt: Optional[bool] = None, # None for TI and LoRA + sd3: str = None, + flux: str = None, # "dev", "schnell" or "chroma" + lumina: str = None, + optional_metadata: dict[str, str] | None = None, ): timestamp = time.time() @@ -2987,6 +3559,34 @@ def get_sai_model_spec( else: timesteps = None + # Convert individual model parameters to model_config dict + # TODO: Update calls to this function to pass in the model config + model_config = {} + if sd3 is not None: + model_config["sd3"] = sd3 + if flux is not None: + model_config["flux"] = flux + if lumina is not None: + model_config["lumina"] = lumina + + # Extract metadata_* fields from args and merge with optional_metadata + extracted_metadata = {} + + # Extract all metadata_* attributes from args + for attr_name in dir(args): + if attr_name.startswith("metadata_") and not attr_name.startswith("metadata___"): + value = getattr(args, attr_name, None) + if value is not None: + # Remove metadata_ prefix and exclude already handled fields + field_name = attr_name[9:] # len("metadata_") = 9 + if field_name not in ["title", "author", "description", "license", "tags"]: + extracted_metadata[field_name] = value + + # Merge extracted metadata with provided optional_metadata + all_optional_metadata = {**extracted_metadata} + if optional_metadata: + all_optional_metadata.update(optional_metadata) + metadata = sai_model_spec.build_metadata( state_dict, v2, @@ -3004,10 +3604,78 @@ def get_sai_model_spec( tags=args.metadata_tags, timesteps=timesteps, clip_skip=args.clip_skip, # None or int + model_config=model_config, + optional_metadata=all_optional_metadata if all_optional_metadata else None, ) return metadata +def get_sai_model_spec_dataclass( + state_dict: dict, + args: argparse.Namespace, + sdxl: bool, + lora: bool, + textual_inversion: bool, + is_stable_diffusion_ckpt: Optional[bool] = None, + sd3: str = None, + flux: str = None, + lumina: str = None, + hunyuan_image: str = None, + optional_metadata: dict[str, str] | None = None, +) -> sai_model_spec.ModelSpecMetadata: + """ + Get ModelSpec metadata as a dataclass - preferred for new code. + Automatically extracts metadata_* fields from args. + """ + timestamp = time.time() + + v2 = args.v2 + v_parameterization = args.v_parameterization + reso = args.resolution + + title = args.metadata_title if args.metadata_title is not None else args.output_name + + if args.min_timestep is not None or args.max_timestep is not None: + min_time_step = args.min_timestep if args.min_timestep is not None else 0 + max_time_step = args.max_timestep if args.max_timestep is not None else 1000 + timesteps = (min_time_step, max_time_step) + else: + timesteps = None + + # Convert individual model parameters to model_config dict + model_config = {} + if sd3 is not None: + model_config["sd3"] = sd3 + if flux is not None: + model_config["flux"] = flux + if lumina is not None: + model_config["lumina"] = lumina + if hunyuan_image is not None: + model_config["hunyuan_image"] = hunyuan_image + + # Use the dataclass function directly + return sai_model_spec.build_metadata_dataclass( + state_dict, + v2, + v_parameterization, + sdxl, + lora, + textual_inversion, + timestamp, + title=title, + reso=reso, + is_stable_diffusion_ckpt=is_stable_diffusion_ckpt, + author=args.metadata_author, + description=args.metadata_description, + license=args.metadata_license, + tags=args.metadata_tags, + timesteps=timesteps, + clip_skip=args.clip_skip, + model_config=model_config, + optional_metadata=optional_metadata, + ) + + def add_sd_models_arguments(parser: argparse.ArgumentParser): # for pretrained models parser.add_argument( @@ -3084,6 +3752,20 @@ def int_or_float(value): help='additional arguments for optimizer (like "weight_decay=0.01 betas=0.9,0.999 ...") / オプティマイザの追加引数(例: "weight_decay=0.01 betas=0.9,0.999 ...")', ) + # parser.add_argument( + # "--optimizer_schedulefree_wrapper", + # action="store_true", + # help="use schedulefree_wrapper any optimizer / 任意のオプティマイザにschedulefree_wrapperを使用", + # ) + + # parser.add_argument( + # "--schedulefree_wrapper_args", + # type=str, + # default=None, + # nargs="*", + # help='additional arguments for schedulefree_wrapper (like "momentum=0.9 weight_decay_at_y=0.1 ...") / オプティマイザの追加引数(例: "momentum=0.9 weight_decay_at_y=0.1 ...")', + # ) + parser.add_argument("--lr_scheduler_type", type=str, default="", help="custom scheduler module / 使用するスケジューラ") parser.add_argument( "--lr_scheduler_args", @@ -3128,8 +3810,8 @@ def int_or_float(value): parser.add_argument( "--fused_backward_pass", action="store_true", - help="Combines backward pass and optimizer step to reduce VRAM usage. Only available in SDXL" - + " / バックワードパスとオプティマイザステップを組み合わせてVRAMの使用量を削減します。SDXLでのみ有効", + help="Combines backward pass and optimizer step to reduce VRAM usage. Only available in SDXL, SD3 and FLUX" + " / バックワードパスとオプティマイザステップを組み合わせてVRAMの使用量を削減します。SDXL、SD3、FLUXでのみ利用可能", ) parser.add_argument( "--lr_scheduler_timescale", @@ -3276,7 +3958,20 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: # available backends: # https://github.com/huggingface/accelerate/blob/d1abd59114ada8ba673e1214218cb2878c13b82d/src/accelerate/utils/dataclasses.py#L376-L388C5 # https://pytorch.org/docs/stable/torch.compiler.html - choices=["eager", "aot_eager", "inductor", "aot_ts_nvfuser", "nvprims_nvfuser", "cudagraphs", "ofi", "fx2trt", "onnxrt"], + choices=[ + "eager", + "aot_eager", + "inductor", + "aot_ts_nvfuser", + "nvprims_nvfuser", + "cudagraphs", + "ofi", + "fx2trt", + "onnxrt", + "tensort", + "ipex", + "tvm", + ], help="dynamo backend type (default is inductor) / dynamoのbackendの種類(デフォルトは inductor)", ) parser.add_argument("--xformers", action="store_true", help="use xformers for CrossAttention / CrossAttentionにxformersを使う") @@ -3327,11 +4022,21 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: choices=["no", "fp16", "bf16"], help="use mixed precision / 混合精度を使う場合、その精度", ) - parser.add_argument("--full_fp16", action="store_true", help="fp16 training including gradients / 勾配も含めてfp16で学習する") parser.add_argument( - "--full_bf16", action="store_true", help="bf16 training including gradients / 勾配も含めてbf16で学習する" + "--full_fp16", + action="store_true", + help="fp16 training including gradients, some models are not supported / 勾配も含めてfp16で学習する、一部のモデルではサポートされていません", + ) + parser.add_argument( + "--full_bf16", + action="store_true", + help="bf16 training including gradients, some models are not supported / 勾配も含めてbf16で学習する、一部のモデルではサポートされていません", ) # TODO move to SDXL training, because it is not supported by SD1/2 - parser.add_argument("--fp8_base", action="store_true", help="use fp8 for base model / base modelにfp8を使う") + parser.add_argument( + "--fp8_base", + action="store_true", + help="use fp8 for base model, some models are not supported / base modelにfp8を使う、一部のモデルではサポートされていません", + ) parser.add_argument( "--ddp_timeout", @@ -3466,8 +4171,8 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: "--loss_type", type=str, default="l2", - choices=["l2", "huber", "smooth_l1"], - help="The type of loss function to use (L2, Huber, or smooth L1), default is L2 / 使用する損失関数の種類(L2、Huber、またはsmooth L1)、デフォルトはL2", + choices=["l1", "l2", "huber", "smooth_l1"], + help="The type of loss function to use (L1, L2, Huber, or smooth L1), default is L2 / 使用する損失関数の種類(L1、L2、Huber、またはsmooth L1)、デフォルトはL2", ) parser.add_argument( "--huber_schedule", @@ -3481,7 +4186,16 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: "--huber_c", type=float, default=0.1, - help="The huber loss parameter. Only used if one of the huber loss modes (huber or smooth l1) is selected with loss_type. default is 0.1 / Huber損失のパラメータ。loss_typeがhuberまたはsmooth l1の場合に有効。デフォルトは0.1", + help="The Huber loss decay parameter. Only used if one of the huber loss modes (huber or smooth l1) is selected with loss_type. default is 0.1" + " / Huber損失の減衰パラメータ。loss_typeがhuberまたはsmooth l1の場合に有効。デフォルトは0.1", + ) + + parser.add_argument( + "--huber_scale", + type=float, + default=1.0, + help="The Huber loss scale parameter. Only used if one of the huber loss modes (huber or smooth l1) is selected with loss_type. default is 1.0" + " / Huber損失のスケールパラメータ。loss_typeがhuberまたはsmooth l1の場合に有効。デフォルトは1.0", ) parser.add_argument( @@ -3551,57 +4265,90 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: parser.add_argument( "--output_config", action="store_true", help="output command line args to given .toml file / 引数を.tomlファイルに出力する" ) + if support_dreambooth: + # DreamBooth training + parser.add_argument( + "--prior_loss_weight", type=float, default=1.0, help="loss weight for regularization images / 正則化画像のlossの重み" + ) + - # SAI Model spec +def add_masked_loss_arguments(parser: argparse.ArgumentParser): parser.add_argument( - "--metadata_title", + "--conditioning_data_dir", type=str, default=None, - help="title for model metadata (default is output_name) / メタデータに書き込まれるモデルタイトル、省略時はoutput_name", + help="conditioning data directory / 条件付けデータのディレクトリ", ) parser.add_argument( - "--metadata_author", - type=str, - default=None, - help="author name for model metadata / メタデータに書き込まれるモデル作者名", + "--masked_loss", + action="store_true", + help="apply mask for calculating loss. conditioning_data_dir is required for dataset. / 損失計算時にマスクを適用する。datasetにはconditioning_data_dirが必要", ) + + +def add_dit_training_arguments(parser: argparse.ArgumentParser): + # Text encoder related arguments parser.add_argument( - "--metadata_description", - type=str, - default=None, - help="description for model metadata / メタデータに書き込まれるモデル説明", + "--cache_text_encoder_outputs", action="store_true", help="cache text encoder outputs / text encoderの出力をキャッシュする" ) parser.add_argument( - "--metadata_license", - type=str, - default=None, - help="license for model metadata / メタデータに書き込まれるモデルライセンス", + "--cache_text_encoder_outputs_to_disk", + action="store_true", + help="cache text encoder outputs to disk / text encoderの出力をディスクにキャッシュする", ) parser.add_argument( - "--metadata_tags", - type=str, + "--text_encoder_batch_size", + type=int, default=None, - help="tags for model metadata, separated by comma / メタデータに書き込まれるモデルタグ、カンマ区切り", + help="text encoder batch size (default: None, use dataset's batch size)" + + " / text encoderのバッチサイズ(デフォルト: None, データセットのバッチサイズを使用)", ) - if support_dreambooth: - # DreamBooth training - parser.add_argument( - "--prior_loss_weight", type=float, default=1.0, help="loss weight for regularization images / 正則化画像のlossの重み" - ) - + # Model loading optimization + parser.add_argument( + "--disable_mmap_load_safetensors", + action="store_true", + help="disable mmap load for safetensors. Speed up model loading in WSL environment / safetensorsのmmapロードを無効にする。WSL環境等でモデル読み込みを高速化できる", + ) -def add_masked_loss_arguments(parser: argparse.ArgumentParser): + # Training arguments. partial copy from Diffusers parser.add_argument( - "--conditioning_data_dir", + "--weighting_scheme", type=str, - default=None, - help="conditioning data directory / 条件付けデータのディレクトリ", + default="uniform", + choices=["sigma_sqrt", "logit_normal", "mode", "cosmap", "none", "uniform"], + help="weighting scheme for timestep distribution. Default is uniform, uniform and none are the same behavior" + " / タイムステップ分布の重み付けスキーム、デフォルトはuniform、uniform と none は同じ挙動", ) parser.add_argument( - "--masked_loss", - action="store_true", - help="apply mask for calculating loss. conditioning_data_dir is required for dataset. / 損失計算時にマスクを適用する。datasetにはconditioning_data_dirが必要", + "--logit_mean", + type=float, + default=0.0, + help="mean to use when using the `'logit_normal'` weighting scheme / `'logit_normal'`重み付けスキームを使用する場合の平均", + ) + parser.add_argument( + "--logit_std", + type=float, + default=1.0, + help="std to use when using the `'logit_normal'` weighting scheme / `'logit_normal'`重み付けスキームを使用する場合のstd", + ) + parser.add_argument( + "--mode_scale", + type=float, + default=1.29, + help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme` / モード重み付けスキームのスケール", + ) + + # offloading + parser.add_argument( + "--blocks_to_swap", + type=int, + default=None, + help="[EXPERIMENTAL] " + "Sets the number of blocks to swap during the forward and backward passes." + "Increasing this number lowers the overall VRAM used during training at the expense of training speed (s/it)." + " / 順伝播および逆伝播中にスワップするブロックの数を設定します。" + "この数を増やすと、トレーニング中のVRAM使用量が減りますが、トレーニング速度(s/it)も低下します。", ) @@ -3695,15 +4442,19 @@ def verify_command_line_training_args(args: argparse.Namespace): ) +def enable_high_vram(args: argparse.Namespace): + if args.highvram: + logger.info("highvram is enabled / highvramが有効です") + global HIGH_VRAM + HIGH_VRAM = True + + def verify_training_args(args: argparse.Namespace): r""" Verify training arguments. Also reflect highvram option to global variable 学習用引数を検証する。あわせて highvram オプションの指定をグローバル変数に反映する """ - if args.highvram: - print("highvram is enabled / highvramが有効です") - global HIGH_VRAM - HIGH_VRAM = True + enable_high_vram(args) if args.v2 and args.clip_skip is not None: logger.warning("v2 with clip_skip will be unexpected / v2でclip_skipを使用することは想定されていません") @@ -3863,6 +4614,12 @@ def add_dataset_arguments( action="store_true", help="cache latents to disk to reduce VRAM usage (augmentations must be disabled) / VRAM削減のためにlatentをディスクにcacheする(augmentationは使用不可)", ) + parser.add_argument( + "--skip_cache_check", + action="store_true", + help="skip the content validation of cache (latent and text encoder output). Cache file existence check is always performed, and cache processing is performed if the file does not exist" + " / cacheの内容の検証をスキップする(latentとテキストエンコーダの出力)。キャッシュファイルの存在確認は常に行われ、ファイルがなければキャッシュ処理が行われる", + ) parser.add_argument( "--enable_bucket", action="store_true", @@ -3893,7 +4650,13 @@ def add_dataset_arguments( action="store_true", help="make bucket for each image without upscaling / 画像を拡大せずbucketを作成します", ) - + parser.add_argument( + "--resize_interpolation", + type=str, + default=None, + choices=["lanczos", "nearest", "bilinear", "linear", "bicubic", "cubic", "area"], + help="Resize interpolation when required. Default: area Options: lanczos, nearest, bilinear, bicubic, area / 必要に応じてサイズ補間を変更します。デフォルト: area オプション: lanczos, nearest, bilinear, bicubic, area", + ) parser.add_argument( "--token_warmup_min", type=int, @@ -4035,7 +4798,6 @@ def read_config_from_file(args: argparse.Namespace, parser: argparse.ArgumentPar config_args = argparse.Namespace(**ignore_nesting_dict) args = parser.parse_args(namespace=config_args) args.config_file = os.path.splitext(args.config_file)[0] - logger.info(args.config_file) return args @@ -4098,7 +4860,7 @@ def task(): accelerator.load_state(dirname) -def get_optimizer(args, trainable_params): +def get_optimizer(args, trainable_params) -> tuple[str, str, object]: # "Optimizer to use: AdamW, AdamW8bit, Lion, SGDNesterov, SGDNesterov8bit, PagedAdamW, PagedAdamW8bit, PagedAdamW32bit, Lion8bit, PagedLion8bit, AdEMAMix8bit, PagedAdEMAMix8bit, DAdaptation(DAdaptAdamPreprint), DAdaptAdaGrad, DAdaptAdam, DAdaptAdan, DAdaptAdanIP, DAdaptLion, DAdaptSGD, Adafactor" optimizer_type = args.optimizer_type @@ -4373,27 +5135,167 @@ def get_optimizer(args, trainable_params): optimizer_class = torch.optim.AdamW optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) + elif optimizer_type.endswith("schedulefree".lower()): + try: + import schedulefree as sf + except ImportError: + raise ImportError("No schedulefree / schedulefreeがインストールされていないようです") + + if optimizer_type == "RAdamScheduleFree".lower(): + optimizer_class = sf.RAdamScheduleFree + logger.info(f"use RAdamScheduleFree optimizer | {optimizer_kwargs}") + elif optimizer_type == "AdamWScheduleFree".lower(): + optimizer_class = sf.AdamWScheduleFree + logger.info(f"use AdamWScheduleFree optimizer | {optimizer_kwargs}") + elif optimizer_type == "SGDScheduleFree".lower(): + optimizer_class = sf.SGDScheduleFree + logger.info(f"use SGDScheduleFree optimizer | {optimizer_kwargs}") + else: + optimizer_class = None + + if optimizer_class is not None: + optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) + if optimizer is None: # 任意のoptimizerを使う - optimizer_type = args.optimizer_type # lowerでないやつ(微妙) - logger.info(f"use {optimizer_type} | {optimizer_kwargs}") - if "." not in optimizer_type: + case_sensitive_optimizer_type = args.optimizer_type # not lower + logger.info(f"use {case_sensitive_optimizer_type} | {optimizer_kwargs}") + + if "." not in case_sensitive_optimizer_type: # from torch.optim optimizer_module = torch.optim - else: - values = optimizer_type.split(".") + else: # from other library + values = case_sensitive_optimizer_type.split(".") optimizer_module = importlib.import_module(".".join(values[:-1])) - optimizer_type = values[-1] + case_sensitive_optimizer_type = values[-1] - optimizer_class = getattr(optimizer_module, optimizer_type) + optimizer_class = getattr(optimizer_module, case_sensitive_optimizer_type) optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) + """ + # wrap any of above optimizer with schedulefree, if optimizer is not schedulefree + if args.optimizer_schedulefree_wrapper and not optimizer_type.endswith("schedulefree".lower()): + try: + import schedulefree as sf + except ImportError: + raise ImportError("No schedulefree / schedulefreeがインストールされていないようです") + + schedulefree_wrapper_kwargs = {} + if args.schedulefree_wrapper_args is not None and len(args.schedulefree_wrapper_args) > 0: + for arg in args.schedulefree_wrapper_args: + key, value = arg.split("=") + value = ast.literal_eval(value) + schedulefree_wrapper_kwargs[key] = value + + sf_wrapper = sf.ScheduleFreeWrapper(optimizer, **schedulefree_wrapper_kwargs) + sf_wrapper.train() # make optimizer as train mode + + # we need to make optimizer as a subclass of torch.optim.Optimizer, we make another Proxy class over SFWrapper + class OptimizerProxy(torch.optim.Optimizer): + def __init__(self, sf_wrapper): + self._sf_wrapper = sf_wrapper + + def __getattr__(self, name): + return getattr(self._sf_wrapper, name) + + # override properties + @property + def state(self): + return self._sf_wrapper.state + + @state.setter + def state(self, state): + self._sf_wrapper.state = state + + @property + def param_groups(self): + return self._sf_wrapper.param_groups + + @param_groups.setter + def param_groups(self, param_groups): + self._sf_wrapper.param_groups = param_groups + + @property + def defaults(self): + return self._sf_wrapper.defaults + + @defaults.setter + def defaults(self, defaults): + self._sf_wrapper.defaults = defaults + + def add_param_group(self, param_group): + self._sf_wrapper.add_param_group(param_group) + + def load_state_dict(self, state_dict): + self._sf_wrapper.load_state_dict(state_dict) + + def state_dict(self): + return self._sf_wrapper.state_dict() + + def zero_grad(self): + self._sf_wrapper.zero_grad() + + def step(self, closure=None): + self._sf_wrapper.step(closure) + + def train(self): + self._sf_wrapper.train() + + def eval(self): + self._sf_wrapper.eval() + + # isinstance チェックをパスするためのメソッド + def __instancecheck__(self, instance): + return isinstance(instance, (type(self), Optimizer)) + + optimizer = OptimizerProxy(sf_wrapper) + + logger.info(f"wrap optimizer with ScheduleFreeWrapper | {schedulefree_wrapper_kwargs}") + """ + # for logging optimizer_name = optimizer_class.__module__ + "." + optimizer_class.__name__ optimizer_args = ",".join([f"{k}={v}" for k, v in optimizer_kwargs.items()]) + if hasattr(optimizer, "train") and callable(optimizer.train): + # make optimizer as train mode before training for schedulefree optimizer. the optimizer will be in eval mode in sampling and saving. + optimizer.train() + return optimizer_name, optimizer_args, optimizer +def get_optimizer_train_eval_fn(optimizer: Optimizer, args: argparse.Namespace) -> Tuple[Callable, Callable]: + if not is_schedulefree_optimizer(optimizer, args): + # return dummy func + return lambda: None, lambda: None + + # get train and eval functions from optimizer + train_fn = optimizer.train + eval_fn = optimizer.eval + + return train_fn, eval_fn + + +def is_schedulefree_optimizer(optimizer: Optimizer, args: argparse.Namespace) -> bool: + return args.optimizer_type.lower().endswith("schedulefree".lower()) # or args.optimizer_schedulefree_wrapper + + +def get_dummy_scheduler(optimizer: Optimizer) -> Any: + # dummy scheduler for schedulefree optimizer. supports only empty step(), get_last_lr() and optimizers. + # this scheduler is used for logging only. + # this isn't be wrapped by accelerator because of this class is not a subclass of torch.optim.lr_scheduler._LRScheduler + class DummyScheduler: + def __init__(self, optimizer: Optimizer): + self.optimizer = optimizer + + def step(self): + pass + + def get_last_lr(self): + return [group["lr"] for group in self.optimizer.param_groups] + + return DummyScheduler(optimizer) + + # Modified version of get_scheduler() function from diffusers.optimizer.get_scheduler # Add some checking and features to the original function. @@ -4402,6 +5304,10 @@ def get_scheduler_fix(args, optimizer: Optimizer, num_processes: int): """ Unified API to get any scheduler from its name. """ + # if schedulefree optimizer, return dummy scheduler + if is_schedulefree_optimizer(optimizer, args): + return get_dummy_scheduler(optimizer) + name = args.lr_scheduler num_training_steps = args.max_train_steps * num_processes # * args.gradient_accumulation_steps num_warmup_steps: Optional[int] = ( @@ -4561,33 +5467,6 @@ def prepare_dataset_args(args: argparse.Namespace, support_metadata: bool): ) -def load_tokenizer(args: argparse.Namespace): - logger.info("prepare tokenizer") - original_path = V2_STABLE_DIFFUSION_PATH if args.v2 else TOKENIZER_PATH - - tokenizer: CLIPTokenizer = None - if args.tokenizer_cache_dir: - local_tokenizer_path = os.path.join(args.tokenizer_cache_dir, original_path.replace("/", "_")) - if os.path.exists(local_tokenizer_path): - logger.info(f"load tokenizer from cache: {local_tokenizer_path}") - tokenizer = CLIPTokenizer.from_pretrained(local_tokenizer_path) # same for v1 and v2 - - if tokenizer is None: - if args.v2: - tokenizer = CLIPTokenizer.from_pretrained(original_path, subfolder="tokenizer") - else: - tokenizer = CLIPTokenizer.from_pretrained(original_path) - - if hasattr(args, "max_token_length") and args.max_token_length is not None: - logger.info(f"update token length: {args.max_token_length}") - - if args.tokenizer_cache_dir and not os.path.exists(local_tokenizer_path): - logger.info(f"save Tokenizer to cache: {local_tokenizer_path}") - tokenizer.save_pretrained(local_tokenizer_path) - - return tokenizer - - def prepare_accelerator(args: argparse.Namespace): """ this function also prepares deepspeed plugin @@ -4627,8 +5506,18 @@ def prepare_accelerator(args: argparse.Namespace): if args.torch_compile: dynamo_backend = args.dynamo_backend - kwargs_handlers = ( - InitProcessGroupKwargs(timeout=datetime.timedelta(minutes=args.ddp_timeout)) if args.ddp_timeout else None, + kwargs_handlers = [ + ( + InitProcessGroupKwargs( + backend="gloo" if os.name == "nt" or not torch.cuda.is_available() else "nccl", + init_method=( + "env://?use_libuv=False" if os.name == "nt" and Version(torch.__version__) >= Version("2.4.0") else None + ), + timeout=datetime.timedelta(minutes=args.ddp_timeout) if args.ddp_timeout else None, + ) + if torch.cuda.device_count() > 1 + else None + ), ( DistributedDataParallelKwargs( gradient_as_bucket_view=args.ddp_gradient_as_bucket_view, static_graph=args.ddp_static_graph @@ -4636,8 +5525,8 @@ def prepare_accelerator(args: argparse.Namespace): if args.ddp_gradient_as_bucket_view or args.ddp_static_graph else None ), - ) - kwargs_handlers = list(filter(lambda x: x is not None, kwargs_handlers)) + ] + kwargs_handlers = [i for i in kwargs_handlers if i is not None] deepspeed_plugin = deepspeed_utils.prepare_deepspeed_plugin(args) accelerator = Accelerator( @@ -4740,6 +5629,12 @@ def load_target_model(args, weight_dtype, accelerator, unet_use_linear_projectio def patch_accelerator_for_fp16_training(accelerator): + + from accelerate import DistributedType + + if accelerator.distributed_type == DistributedType.DEEPSPEED: + return + org_unscale_grads = accelerator.scaler._unscale_grads_ def _unscale_grads_replacer(optimizer, inv_scale, found_inf, allow_fp16): @@ -5202,32 +6097,18 @@ def save_sd_model_on_train_end_common( huggingface_util.upload(args, out_dir, "/" + model_name, force_sync_upload=True) -def get_timesteps_and_huber_c(args, min_timestep, max_timestep, noise_scheduler, b_size, device): - timesteps = torch.randint(min_timestep, max_timestep, (b_size,), device="cpu") - - if args.loss_type == "huber" or args.loss_type == "smooth_l1": - if args.huber_schedule == "exponential": - alpha = -math.log(args.huber_c) / noise_scheduler.config.num_train_timesteps - huber_c = torch.exp(-alpha * timesteps) - elif args.huber_schedule == "snr": - alphas_cumprod = torch.index_select(noise_scheduler.alphas_cumprod, 0, timesteps) - sigmas = ((1.0 - alphas_cumprod) / alphas_cumprod) ** 0.5 - huber_c = (1 - args.huber_c) / (1 + sigmas) ** 2 + args.huber_c - elif args.huber_schedule == "constant": - huber_c = torch.full((b_size,), args.huber_c) - else: - raise NotImplementedError(f"Unknown Huber loss schedule {args.huber_schedule}!") - huber_c = huber_c.to(device) - elif args.loss_type == "l2": - huber_c = None # may be anything, as it's not used +def get_timesteps(min_timestep: int, max_timestep: int, b_size: int, device: torch.device) -> torch.Tensor: + if min_timestep < max_timestep: + timesteps = torch.randint(min_timestep, max_timestep, (b_size,), device="cpu") else: - raise NotImplementedError(f"Unknown loss type {args.loss_type}") - + timesteps = torch.full((b_size,), max_timestep, device="cpu") timesteps = timesteps.long().to(device) - return timesteps, huber_c + return timesteps -def get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents): +def get_noise_noisy_latents_and_timesteps( + args, noise_scheduler, latents: torch.FloatTensor +) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.IntTensor]: # Sample noise that we'll add to the latents noise = torch.randn_like(latents, device=latents.device) if args.noise_offset: @@ -5245,8 +6126,7 @@ def get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents): b_size = latents.shape[0] min_timestep = 0 if args.min_timestep is None else args.min_timestep max_timestep = noise_scheduler.config.num_train_timesteps if args.max_timestep is None else args.max_timestep - - timesteps, huber_c = get_timesteps_and_huber_c(args, min_timestep, max_timestep, noise_scheduler, b_size, latents.device) + timesteps = get_timesteps(min_timestep, max_timestep, b_size, latents.device) # Add noise to the latents according to the noise magnitude at each timestep # (this is the forward diffusion process) @@ -5259,16 +6139,48 @@ def get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents): else: noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) - return noise, noisy_latents, timesteps, huber_c + # This moves the alphas_cumprod back to the CPU after it is moved in noise_scheduler.add_noise + noise_scheduler.alphas_cumprod = noise_scheduler.alphas_cumprod.cpu() + + return noise, noisy_latents, timesteps + + +def get_huber_threshold_if_needed(args, timesteps: torch.Tensor, noise_scheduler) -> Optional[torch.Tensor]: + if not (args.loss_type == "huber" or args.loss_type == "smooth_l1"): + return None + + b_size = timesteps.shape[0] + if args.huber_schedule == "exponential": + alpha = -math.log(args.huber_c) / noise_scheduler.config.num_train_timesteps + result = torch.exp(-alpha * timesteps) * args.huber_scale + elif args.huber_schedule == "snr": + if not hasattr(noise_scheduler, "alphas_cumprod"): + raise NotImplementedError("Huber schedule 'snr' is not supported with the current model.") + alphas_cumprod = torch.index_select(noise_scheduler.alphas_cumprod, 0, timesteps.cpu()) + sigmas = ((1.0 - alphas_cumprod) / alphas_cumprod) ** 0.5 + result = (1 - args.huber_c) / (1 + sigmas) ** 2 + args.huber_c + result = result.to(timesteps.device) + elif args.huber_schedule == "constant": + result = torch.full((b_size,), args.huber_c * args.huber_scale, device=timesteps.device) + else: + raise NotImplementedError(f"Unknown Huber loss schedule {args.huber_schedule}!") + + return result def conditional_loss( - model_pred: torch.Tensor, target: torch.Tensor, reduction: str, loss_type: str, huber_c: Optional[torch.Tensor] + model_pred: torch.Tensor, target: torch.Tensor, loss_type: str, reduction: str, huber_c: Optional[torch.Tensor] = None ): - + """ + NOTE: if you're using the scheduled version, huber_c has to depend on the timesteps already + """ if loss_type == "l2": loss = torch.nn.functional.mse_loss(model_pred, target, reduction=reduction) + elif loss_type == "l1": + loss = torch.nn.functional.l1_loss(model_pred, target, reduction=reduction) elif loss_type == "huber": + if huber_c is None: + raise NotImplementedError("huber_c not implemented correctly") huber_c = huber_c.view(-1, 1, 1, 1) loss = 2 * huber_c * (torch.sqrt((model_pred - target) ** 2 + huber_c**2) - huber_c) if reduction == "mean": @@ -5276,6 +6188,8 @@ def conditional_loss( elif reduction == "sum": loss = torch.sum(loss) elif loss_type == "smooth_l1": + if huber_c is None: + raise NotImplementedError("huber_c not implemented correctly") huber_c = huber_c.view(-1, 1, 1, 1) loss = 2 * (torch.sqrt((model_pred - target) ** 2 + huber_c**2) - huber_c) if reduction == "mean": @@ -5283,7 +6197,7 @@ def conditional_loss( elif reduction == "sum": loss = torch.sum(loss) else: - raise NotImplementedError(f"Unsupported Loss Type {loss_type}") + raise NotImplementedError(f"Unsupported Loss Type: {loss_type}") return loss @@ -5293,6 +6207,7 @@ def append_lr_to_logs(logs, lr_scheduler, optimizer_type, including_unet=True): names.append("unet") names.append("text_encoder1") names.append("text_encoder2") + names.append("text_encoder3") # SD3 append_lr_to_logs_with_names(logs, lr_scheduler, optimizer_type, names) @@ -5405,6 +6320,11 @@ def line_to_prompt_dict(line: str) -> dict: prompt_dict["scale"] = float(m.group(1)) continue + m = re.match(r"g ([\d\.]+)", parg, re.IGNORECASE) + if m: # guidance scale + prompt_dict["guidance_scale"] = float(m.group(1)) + continue + m = re.match(r"n (.+)", parg, re.IGNORECASE) if m: # negative prompt prompt_dict["negative_prompt"] = m.group(1) @@ -5420,6 +6340,21 @@ def line_to_prompt_dict(line: str) -> dict: prompt_dict["controlnet_image"] = m.group(1) continue + m = re.match(r"ctr (.+)", parg, re.IGNORECASE) + if m: + prompt_dict["cfg_trunc_ratio"] = float(m.group(1)) + continue + + m = re.match(r"rcfg (.+)", parg, re.IGNORECASE) + if m: + prompt_dict["renorm_cfg"] = float(m.group(1)) + continue + + m = re.match(r"fs (.+)", parg, re.IGNORECASE) + if m: + prompt_dict["flow_shift"] = m.group(1) + continue + except ValueError as ex: logger.error(f"Exception in parsing / 解析エラー: {parg}") logger.error(ex) @@ -5427,22 +6362,54 @@ def line_to_prompt_dict(line: str) -> dict: return prompt_dict +def load_prompts(prompt_file: str) -> List[Dict]: + # read prompts + if prompt_file.endswith(".txt"): + with open(prompt_file, "r", encoding="utf-8") as f: + lines = f.readlines() + prompts = [line.strip() for line in lines if len(line.strip()) > 0 and line[0] != "#"] + elif prompt_file.endswith(".toml"): + with open(prompt_file, "r", encoding="utf-8") as f: + data = toml.load(f) + prompts = [dict(**data["prompt"], **subset) for subset in data["prompt"]["subset"]] + elif prompt_file.endswith(".json"): + with open(prompt_file, "r", encoding="utf-8") as f: + prompts = json.load(f) + + # preprocess prompts + for i in range(len(prompts)): + prompt_dict = prompts[i] + if isinstance(prompt_dict, str): + from library.train_util import line_to_prompt_dict + + prompt_dict = line_to_prompt_dict(prompt_dict) + prompts[i] = prompt_dict + assert isinstance(prompt_dict, dict) + + # Adds an enumerator to the dict based on prompt position. Used later to name image files. Also cleanup of extra data in original prompt dict. + prompt_dict["enum"] = i + prompt_dict.pop("subset", None) + + return prompts + + def sample_images_common( pipe_class, accelerator: Accelerator, args: argparse.Namespace, - epoch, - steps, + epoch: int, + steps: int, device, vae, tokenizer, text_encoder, - unet, + unet_wrapped, prompt_replacement=None, controlnet=None, ): """ StableDiffusionLongPromptWeightingPipelineの改造版を使うようにしたので、clip skipおよびプロンプトの重みづけに対応した + TODO Use strategies here """ if steps == 0: @@ -5471,7 +6438,7 @@ def sample_images_common( vae.to(distributed_state.device) # distributed_state.device is same as accelerator.device # unwrap unet and text_encoder(s) - unet = accelerator.unwrap_model(unet) + unet = accelerator.unwrap_model(unet_wrapped) if isinstance(text_encoder, (list, tuple)): text_encoder = [accelerator.unwrap_model(te) for te in text_encoder] else: @@ -5490,11 +6457,7 @@ def sample_images_common( with open(args.sample_prompts, "r", encoding="utf-8") as f: prompts = json.load(f) - # schedulers: dict = {} cannot find where this is used - default_scheduler = get_my_scheduler( - sample_sampler=args.sample_sampler, - v_parameterization=args.v_parameterization, - ) + default_scheduler = get_my_scheduler(sample_sampler=args.sample_sampler, v_parameterization=args.v_parameterization) pipeline = pipe_class( text_encoder=text_encoder, @@ -5555,21 +6518,18 @@ def sample_images_common( # clear pipeline and cache to reduce vram usage del pipeline - # I'm not sure which of these is the correct way to clear the memory, but accelerator's device is used in the pipeline, so I'm using it here. - # with torch.cuda.device(torch.cuda.current_device()): - # torch.cuda.empty_cache() - clean_memory_on_device(accelerator.device) - torch.set_rng_state(rng_state) if torch.cuda.is_available() and cuda_rng_state is not None: torch.cuda.set_rng_state(cuda_rng_state) vae.to(org_vae_device) + clean_memory_on_device(accelerator.device) + def sample_image_inference( accelerator: Accelerator, args: argparse.Namespace, - pipeline, + pipeline: Union[StableDiffusionLongPromptWeightingPipeline, SdxlStableDiffusionLongPromptWeightingPipeline], save_dir, prompt_dict, epoch, @@ -5624,7 +6584,7 @@ def sample_image_inference( logger.info(f"sample_sampler: {sampler_name}") if seed is not None: logger.info(f"seed: {seed}") - with accelerator.autocast(): + with accelerator.autocast(), torch.no_grad(): latents = pipeline( prompt=prompt, height=height, @@ -5652,17 +6612,42 @@ def sample_image_inference( img_filename = f"{'' if args.output_name is None else args.output_name + '_'}{num_suffix}_{i:02d}_{ts_str}{seed_suffix}.png" image.save(os.path.join(save_dir, img_filename)) - # wandb有効時のみログを送信 - try: + # send images to wandb if enabled + if "wandb" in [tracker.name for tracker in accelerator.trackers]: wandb_tracker = accelerator.get_tracker("wandb") - try: + + import wandb + + # not to commit images to avoid inconsistency between training and logging steps + wandb_tracker.log({f"sample_{i}": wandb.Image(image, caption=prompt)}, commit=False) # positive prompt as a caption + + +def init_trackers(accelerator: Accelerator, args: argparse.Namespace, default_tracker_name: str): + """ + Initialize experiment trackers with tracker specific behaviors + """ + if accelerator.is_main_process: + init_kwargs = {} + if args.wandb_run_name: + init_kwargs["wandb"] = {"name": args.wandb_run_name} + if args.log_tracker_config is not None: + init_kwargs = toml.load(args.log_tracker_config) + accelerator.init_trackers( + default_tracker_name if args.log_tracker_name is None else args.log_tracker_name, + config=get_sanitized_config_or_none(args), + init_kwargs=init_kwargs, + ) + + if "wandb" in [tracker.name for tracker in accelerator.trackers]: import wandb - except ImportError: # 事前に一度確認するのでここはエラー出ないはず - raise ImportError("No wandb / wandb がインストールされていないようです") - wandb_tracker.log({f"sample_{i}": wandb.Image(image)}) - except: # wandb 無効時 - pass + wandb_tracker = accelerator.get_tracker("wandb", unwrap=True) + + # Define specific metrics to handle validation and epochs "steps" + wandb_tracker.define_metric("epoch", hidden=True) + wandb_tracker.define_metric("val_step", hidden=True) + + wandb_tracker.define_metric("global_step", hidden=True) # endregion @@ -5733,4 +6718,7 @@ def add(self, *, epoch: int, step: int, loss: float) -> None: @property def moving_average(self) -> float: - return self.loss_total / len(self.loss_list) + losses = len(self.loss_list) + if losses == 0: + return 0 + return self.loss_total / losses diff --git a/library/utils.py b/library/utils.py index 49d46a546..296fc4151 100644 --- a/library/utils.py +++ b/library/utils.py @@ -1,9 +1,11 @@ import logging import sys import threading +from typing import * + import torch +import torch.nn as nn from torchvision import transforms -from typing import * from diffusers import EulerAncestralDiscreteScheduler import diffusers.schedulers.scheduling_euler_ancestral_discrete from diffusers.schedulers.scheduling_euler_ancestral_discrete import EulerAncestralDiscreteSchedulerOutput @@ -16,6 +18,9 @@ def fire_in_thread(f, *args, **kwargs): threading.Thread(target=f, args=args, kwargs=kwargs).start() +# region Logging + + def add_logging_arguments(parser): parser.add_argument( "--console_log_level", @@ -82,7 +87,114 @@ def setup_logging(args=None, log_level=None, reset=False): logger.info(msg_init) -def pil_resize(image, size, interpolation=Image.LANCZOS): +setup_logging() +logger = logging.getLogger(__name__) + +# endregion + +# region PyTorch utils + + +def swap_weight_devices(layer_to_cpu: nn.Module, layer_to_cuda: nn.Module): + assert layer_to_cpu.__class__ == layer_to_cuda.__class__ + + weight_swap_jobs = [] + for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()): + if hasattr(module_to_cpu, "weight") and module_to_cpu.weight is not None: + weight_swap_jobs.append((module_to_cpu, module_to_cuda, module_to_cpu.weight.data, module_to_cuda.weight.data)) + + torch.cuda.current_stream().synchronize() # this prevents the illegal loss value + + stream = torch.cuda.Stream() + with torch.cuda.stream(stream): + # cuda to cpu + for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs: + cuda_data_view.record_stream(stream) + module_to_cpu.weight.data = cuda_data_view.data.to("cpu", non_blocking=True) + + stream.synchronize() + + # cpu to cuda + for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs: + cuda_data_view.copy_(module_to_cuda.weight.data, non_blocking=True) + module_to_cuda.weight.data = cuda_data_view + + stream.synchronize() + torch.cuda.current_stream().synchronize() # this prevents the illegal loss value + + +def weighs_to_device(layer: nn.Module, device: torch.device): + for module in layer.modules(): + if hasattr(module, "weight") and module.weight is not None: + module.weight.data = module.weight.data.to(device, non_blocking=True) + + +def str_to_dtype(s: Optional[str], default_dtype: Optional[torch.dtype] = None) -> torch.dtype: + """ + Convert a string to a torch.dtype + + Args: + s: string representation of the dtype + default_dtype: default dtype to return if s is None + + Returns: + torch.dtype: the corresponding torch.dtype + + Raises: + ValueError: if the dtype is not supported + + Examples: + >>> str_to_dtype("float32") + torch.float32 + >>> str_to_dtype("fp32") + torch.float32 + >>> str_to_dtype("float16") + torch.float16 + >>> str_to_dtype("fp16") + torch.float16 + >>> str_to_dtype("bfloat16") + torch.bfloat16 + >>> str_to_dtype("bf16") + torch.bfloat16 + >>> str_to_dtype("fp8") + torch.float8_e4m3fn + >>> str_to_dtype("fp8_e4m3fn") + torch.float8_e4m3fn + >>> str_to_dtype("fp8_e4m3fnuz") + torch.float8_e4m3fnuz + >>> str_to_dtype("fp8_e5m2") + torch.float8_e5m2 + >>> str_to_dtype("fp8_e5m2fnuz") + torch.float8_e5m2fnuz + """ + if s is None: + return default_dtype + if s in ["bf16", "bfloat16"]: + return torch.bfloat16 + elif s in ["fp16", "float16"]: + return torch.float16 + elif s in ["fp32", "float32", "float"]: + return torch.float32 + elif s in ["fp8_e4m3fn", "e4m3fn", "float8_e4m3fn"]: + return torch.float8_e4m3fn + elif s in ["fp8_e4m3fnuz", "e4m3fnuz", "float8_e4m3fnuz"]: + return torch.float8_e4m3fnuz + elif s in ["fp8_e5m2", "e5m2", "float8_e5m2"]: + return torch.float8_e5m2 + elif s in ["fp8_e5m2fnuz", "e5m2fnuz", "float8_e5m2fnuz"]: + return torch.float8_e5m2fnuz + elif s in ["fp8", "float8"]: + return torch.float8_e4m3fn # default fp8 + else: + raise ValueError(f"Unsupported dtype: {s}") + + +# endregion + +# region Image utils + + +def pil_resize(image, size, interpolation): has_alpha = image.shape[2] == 4 if len(image.shape) == 3 else False if has_alpha: @@ -90,7 +202,7 @@ def pil_resize(image, size, interpolation=Image.LANCZOS): else: pil_image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB)) - resized_pil = pil_image.resize(size, interpolation) + resized_pil = pil_image.resize(size, resample=interpolation) # Convert back to cv2 format if has_alpha: @@ -101,9 +213,130 @@ def pil_resize(image, size, interpolation=Image.LANCZOS): return resized_cv2 -# TODO make inf_utils.py +def resize_image( + image: np.ndarray, + width: int, + height: int, + resized_width: int, + resized_height: int, + resize_interpolation: Optional[str] = None, +): + """ + Resize image with resize interpolation. Default interpolation to AREA if image is smaller, else LANCZOS. + + Args: + image: numpy.ndarray + width: int Original image width + height: int Original image height + resized_width: int Resized image width + resized_height: int Resized image height + resize_interpolation: Optional[str] Resize interpolation method "lanczos", "area", "bilinear", "bicubic", "nearest", "box" + + Returns: + image + """ + + # Ensure all size parameters are actual integers + width = int(width) + height = int(height) + resized_width = int(resized_width) + resized_height = int(resized_height) + + if resize_interpolation is None: + if width >= resized_width and height >= resized_height: + resize_interpolation = "area" + else: + resize_interpolation = "lanczos" + + # we use PIL for lanczos (for backward compatibility) and box, cv2 for others + use_pil = resize_interpolation in ["lanczos", "lanczos4", "box"] + + resized_size = (resized_width, resized_height) + if use_pil: + interpolation = get_pil_interpolation(resize_interpolation) + image = pil_resize(image, resized_size, interpolation=interpolation) + logger.debug(f"resize image using {resize_interpolation} (PIL)") + else: + interpolation = get_cv2_interpolation(resize_interpolation) + image = cv2.resize(image, resized_size, interpolation=interpolation) + logger.debug(f"resize image using {resize_interpolation} (cv2)") + + return image + + +def get_cv2_interpolation(interpolation: Optional[str]) -> Optional[int]: + """ + Convert interpolation value to cv2 interpolation integer + + https://docs.opencv.org/3.4/da/d54/group__imgproc__transform.html#ga5bb5a1fea74ea38e1a5445ca803ff121 + """ + if interpolation is None: + return None + + if interpolation == "lanczos" or interpolation == "lanczos4": + # Lanczos interpolation over 8x8 neighborhood + return cv2.INTER_LANCZOS4 + elif interpolation == "nearest": + # Bit exact nearest neighbor interpolation. This will produce same results as the nearest neighbor method in PIL, scikit-image or Matlab. + return cv2.INTER_NEAREST_EXACT + elif interpolation == "bilinear" or interpolation == "linear": + # bilinear interpolation + return cv2.INTER_LINEAR + elif interpolation == "bicubic" or interpolation == "cubic": + # bicubic interpolation + return cv2.INTER_CUBIC + elif interpolation == "area": + # resampling using pixel area relation. It may be a preferred method for image decimation, as it gives moire'-free results. But when the image is zoomed, it is similar to the INTER_NEAREST method. + return cv2.INTER_AREA + elif interpolation == "box": + # resampling using pixel area relation. It may be a preferred method for image decimation, as it gives moire'-free results. But when the image is zoomed, it is similar to the INTER_NEAREST method. + return cv2.INTER_AREA + else: + return None + + +def get_pil_interpolation(interpolation: Optional[str]) -> Optional[Image.Resampling]: + """ + Convert interpolation value to PIL interpolation + + https://pillow.readthedocs.io/en/stable/handbook/concepts.html#concept-filters + """ + if interpolation is None: + return None + + if interpolation == "lanczos": + return Image.Resampling.LANCZOS + elif interpolation == "nearest": + # Pick one nearest pixel from the input image. Ignore all other input pixels. + return Image.Resampling.NEAREST + elif interpolation == "bilinear" or interpolation == "linear": + # For resize calculate the output pixel value using linear interpolation on all pixels that may contribute to the output value. For other transformations linear interpolation over a 2x2 environment in the input image is used. + return Image.Resampling.BILINEAR + elif interpolation == "bicubic" or interpolation == "cubic": + # For resize calculate the output pixel value using cubic interpolation on all pixels that may contribute to the output value. For other transformations cubic interpolation over a 4x4 environment in the input image is used. + return Image.Resampling.BICUBIC + elif interpolation == "area": + # Image.Resampling.BOX may be more appropriate if upscaling + # Area interpolation is related to cv2.INTER_AREA + # Produces a sharper image than Resampling.BILINEAR, doesn’t have dislocations on local level like with Resampling.BOX. + return Image.Resampling.HAMMING + elif interpolation == "box": + # Each pixel of source image contributes to one pixel of the destination image with identical weights. For upscaling is equivalent of Resampling.NEAREST. + return Image.Resampling.BOX + else: + return None +def validate_interpolation_fn(interpolation_str: str) -> bool: + """ + Check if a interpolation function is supported + """ + return interpolation_str in ["lanczos", "nearest", "bilinear", "linear", "bicubic", "cubic", "area", "box"] + + +# endregion + +# TODO make inf_utils.py # region Gradual Latent hires fix @@ -234,7 +467,9 @@ def step( elif self.config.prediction_type == "sample": raise NotImplementedError("prediction_type not implemented yet: sample") else: - raise ValueError(f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`") + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`" + ) sigma_from = self.sigmas[self.step_index] sigma_to = self.sigmas[self.step_index + 1] diff --git a/lumina_minimal_inference.py b/lumina_minimal_inference.py new file mode 100644 index 000000000..47d6d30b9 --- /dev/null +++ b/lumina_minimal_inference.py @@ -0,0 +1,418 @@ +# Minimum Inference Code for Lumina +# Based on flux_minimal_inference.py + +import logging +import argparse +import math +import os +import random +import time +from typing import Optional + +import einops +import numpy as np +import torch +from accelerate import Accelerator +from PIL import Image +from safetensors.torch import load_file +from tqdm import tqdm +from transformers import Gemma2Model +from library.flux_models import AutoEncoder + +from library import ( + device_utils, + lumina_models, + lumina_train_util, + lumina_util, + sd3_train_utils, + strategy_lumina, +) +import networks.lora_lumina as lora_lumina +from library.device_utils import get_preferred_device, init_ipex +from library.utils import setup_logging, str_to_dtype + +init_ipex() +setup_logging() +logger = logging.getLogger(__name__) + + +def generate_image( + model: lumina_models.NextDiT, + gemma2: Gemma2Model, + ae: AutoEncoder, + prompt: str, + system_prompt: str, + seed: Optional[int], + image_width: int, + image_height: int, + steps: int, + guidance_scale: float, + negative_prompt: Optional[str], + args: argparse.Namespace, + cfg_trunc_ratio: float = 0.25, + renorm_cfg: float = 1.0, +): + # + # 0. Prepare arguments + # + device = get_preferred_device() + if args.device: + device = torch.device(args.device) + + dtype = str_to_dtype(args.dtype) + ae_dtype = str_to_dtype(args.ae_dtype) + gemma2_dtype = str_to_dtype(args.gemma2_dtype) + + # + # 1. Prepare models + # + # model.to(device, dtype=dtype) + model.to(dtype) + model.eval() + + gemma2.to(device, dtype=gemma2_dtype) + gemma2.eval() + + ae.to(ae_dtype) + ae.eval() + + # + # 2. Encode prompts + # + logger.info("Encoding prompts...") + + tokenize_strategy = strategy_lumina.LuminaTokenizeStrategy(system_prompt, args.gemma2_max_token_length) + encoding_strategy = strategy_lumina.LuminaTextEncodingStrategy() + + tokens_and_masks = tokenize_strategy.tokenize(prompt) + with torch.no_grad(): + gemma2_conds = encoding_strategy.encode_tokens(tokenize_strategy, [gemma2], tokens_and_masks) + + tokens_and_masks = tokenize_strategy.tokenize( + negative_prompt, is_negative=True and not args.add_system_prompt_to_negative_prompt + ) + with torch.no_grad(): + neg_gemma2_conds = encoding_strategy.encode_tokens(tokenize_strategy, [gemma2], tokens_and_masks) + + # Unpack Gemma2 outputs + prompt_hidden_states, _, prompt_attention_mask = gemma2_conds + uncond_hidden_states, _, uncond_attention_mask = neg_gemma2_conds + + if args.offload: + print("Offloading models to CPU to save VRAM...") + gemma2.to("cpu") + device_utils.clean_memory() + + model.to(device) + + # + # 3. Prepare latents + # + seed = seed if seed is not None else random.randint(0, 2**32 - 1) + logger.info(f"Seed: {seed}") + torch.manual_seed(seed) + + latent_height = image_height // 8 + latent_width = image_width // 8 + latent_channels = 16 + + latents = torch.randn( + (1, latent_channels, latent_height, latent_width), + device=device, + dtype=dtype, + generator=torch.Generator(device=device).manual_seed(seed), + ) + + # + # 4. Denoise + # + logger.info("Denoising...") + scheduler = sd3_train_utils.FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=args.discrete_flow_shift) + scheduler.set_timesteps(steps, device=device) + timesteps = scheduler.timesteps + + # # compare with lumina_train_util.retrieve_timesteps + # lumina_timestep = lumina_train_util.retrieve_timesteps(scheduler, num_inference_steps=steps) + # print(f"Using timesteps: {timesteps}") + # print(f"vs Lumina timesteps: {lumina_timestep}") # should be the same + + with torch.autocast(device_type=device.type, dtype=dtype), torch.no_grad(): + latents = lumina_train_util.denoise( + scheduler, + model, + latents.to(device), + prompt_hidden_states.to(device), + prompt_attention_mask.to(device), + uncond_hidden_states.to(device), + uncond_attention_mask.to(device), + timesteps, + guidance_scale, + cfg_trunc_ratio, + renorm_cfg, + ) + + if args.offload: + model.to("cpu") + device_utils.clean_memory() + ae.to(device) + + # + # 5. Decode latents + # + logger.info("Decoding image...") + # latents = latents / ae.scale_factor + ae.shift_factor + with torch.no_grad(): + image = ae.decode(latents.to(ae_dtype)) + image = (image / 2 + 0.5).clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + image = (image * 255).round().astype("uint8") + + # + # 6. Save image + # + pil_image = Image.fromarray(image[0]) + output_dir = args.output_dir + os.makedirs(output_dir, exist_ok=True) + ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime()) + seed_suffix = f"_{seed}" + output_path = os.path.join(output_dir, f"image_{ts_str}{seed_suffix}.png") + pil_image.save(output_path) + logger.info(f"Image saved to {output_path}") + + +def setup_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser() + parser.add_argument( + "--pretrained_model_name_or_path", + type=str, + default=None, + required=True, + help="Lumina DiT model path / Lumina DiTモデルのパス", + ) + parser.add_argument( + "--gemma2_path", + type=str, + default=None, + required=True, + help="Gemma2 model path / Gemma2モデルのパス", + ) + parser.add_argument( + "--ae_path", + type=str, + default=None, + required=True, + help="Autoencoder model path / Autoencoderモデルのパス", + ) + parser.add_argument("--prompt", type=str, default="A beautiful sunset over the mountains", help="Prompt for image generation") + parser.add_argument("--negative_prompt", type=str, default="", help="Negative prompt for image generation, default is empty") + parser.add_argument("--output_dir", type=str, default="outputs", help="Output directory for generated images") + parser.add_argument("--seed", type=int, default=None, help="Random seed") + parser.add_argument("--steps", type=int, default=36, help="Number of inference steps") + parser.add_argument("--guidance_scale", type=float, default=3.5, help="Guidance scale for classifier-free guidance") + parser.add_argument("--image_width", type=int, default=1024, help="Image width") + parser.add_argument("--image_height", type=int, default=1024, help="Image height") + parser.add_argument("--dtype", type=str, default="bf16", help="Data type for model (bf16, fp16, float)") + parser.add_argument("--gemma2_dtype", type=str, default="bf16", help="Data type for Gemma2 (bf16, fp16, float)") + parser.add_argument("--ae_dtype", type=str, default="bf16", help="Data type for Autoencoder (bf16, fp16, float)") + parser.add_argument("--device", type=str, default=None, help="Device to use (e.g., 'cuda:0')") + parser.add_argument("--offload", action="store_true", help="Offload models to CPU to save VRAM") + parser.add_argument("--system_prompt", type=str, default="", help="System prompt for Gemma2 model") + parser.add_argument("--add_system_prompt_to_negative_prompt", action="store_true", help="Add system prompt to negative prompt") + parser.add_argument( + "--gemma2_max_token_length", + type=int, + default=256, + help="Max token length for Gemma2 tokenizer", + ) + parser.add_argument( + "--discrete_flow_shift", + type=float, + default=6.0, + help="Shift value for FlowMatchEulerDiscreteScheduler", + ) + parser.add_argument( + "--cfg_trunc_ratio", + type=float, + default=0.25, + help="The ratio of the timestep interval to apply normalization-based guidance scale. For example, 0.25 means the first 25%% of timesteps will be guided.", + ) + parser.add_argument( + "--renorm_cfg", + type=float, + default=1.0, + help="The factor to limit the maximum norm after guidance. Default: 1.0, 0.0 means no renormalization.", + ) + parser.add_argument( + "--use_flash_attn", + action="store_true", + help="Use flash attention for Lumina model", + ) + parser.add_argument( + "--use_sage_attn", + action="store_true", + help="Use sage attention for Lumina model", + ) + parser.add_argument( + "--lora_weights", + type=str, + nargs="*", + default=[], + help="LoRA weights, each argument is a `path;multiplier` (semi-colon separated)", + ) + parser.add_argument("--merge_lora_weights", action="store_true", help="Merge LoRA weights to model") + parser.add_argument( + "--interactive", + action="store_true", + help="Enable interactive mode for generating multiple images / 対話モードで複数の画像を生成する", + ) + return parser + + +if __name__ == "__main__": + parser = setup_parser() + args = parser.parse_args() + + logger.info("Loading models...") + device = get_preferred_device() + if args.device: + device = torch.device(args.device) + + # Load Lumina DiT model + model = lumina_util.load_lumina_model( + args.pretrained_model_name_or_path, + dtype=None, # Load in fp32 and then convert + device="cpu", + use_flash_attn=args.use_flash_attn, + use_sage_attn=args.use_sage_attn, + ) + + # Load Gemma2 + gemma2 = lumina_util.load_gemma2(args.gemma2_path, dtype=None, device="cpu") + + # Load Autoencoder + ae = lumina_util.load_ae(args.ae_path, dtype=None, device="cpu") + + # LoRA + lora_models = [] + for weights_file in args.lora_weights: + if ";" in weights_file: + weights_file, multiplier = weights_file.split(";") + multiplier = float(multiplier) + else: + multiplier = 1.0 + + weights_sd = load_file(weights_file) + lora_model, _ = lora_lumina.create_network_from_weights(multiplier, None, ae, [gemma2], model, weights_sd, True) + + if args.merge_lora_weights: + lora_model.merge_to([gemma2], model, weights_sd) + else: + lora_model.apply_to([gemma2], model) + info = lora_model.load_state_dict(weights_sd, strict=True) + logger.info(f"Loaded LoRA weights from {weights_file}: {info}") + lora_model.to(device) + lora_model.set_multiplier(multiplier) + lora_model.eval() + + lora_models.append(lora_model) + + if not args.interactive: + generate_image( + model, + gemma2, + ae, + args.prompt, + args.system_prompt, + args.seed, + args.image_width, + args.image_height, + args.steps, + args.guidance_scale, + args.negative_prompt, + args, + args.cfg_trunc_ratio, + args.renorm_cfg, + ) + else: + # Interactive mode loop + image_width = args.image_width + image_height = args.image_height + steps = args.steps + guidance_scale = args.guidance_scale + cfg_trunc_ratio = args.cfg_trunc_ratio + renorm_cfg = args.renorm_cfg + + print("Entering interactive mode.") + while True: + print( + "\nEnter prompt (or 'exit'). Options: --w --h --s --d --g --n --ctr --rcfg --m " + ) + user_input = input() + if user_input.lower() == "exit": + break + if not user_input: + continue + + # Parse options + options = user_input.split("--") + prompt = options[0].strip() + + # Set defaults for each generation + seed = None # New random seed each time unless specified + negative_prompt = args.negative_prompt # Reset to default + + for opt in options[1:]: + try: + opt = opt.strip() + if not opt: + continue + + key, value = (opt.split(None, 1) + [""])[:2] + + if key == "w": + image_width = int(value) + elif key == "h": + image_height = int(value) + elif key == "s": + steps = int(value) + elif key == "d": + seed = int(value) + elif key == "g": + guidance_scale = float(value) + elif key == "n": + negative_prompt = value if value != "-" else "" + elif key == "ctr": + cfg_trunc_ratio = float(value) + elif key == "rcfg": + renorm_cfg = float(value) + elif key == "m": + multipliers = value.split(",") + if len(multipliers) != len(lora_models): + logger.error(f"Invalid number of multipliers, expected {len(lora_models)}") + continue + for i, lora_model in enumerate(lora_models): + lora_model.set_multiplier(float(multipliers[i].strip())) + else: + logger.warning(f"Unknown option: --{key}") + + except (ValueError, IndexError) as e: + logger.error(f"Invalid value for option --{key}: '{value}'. Error: {e}") + + generate_image( + model, + gemma2, + ae, + prompt, + args.system_prompt, + seed, + image_width, + image_height, + steps, + guidance_scale, + negative_prompt, + args, + cfg_trunc_ratio, + renorm_cfg, + ) + + logger.info("Done.") diff --git a/lumina_train.py b/lumina_train.py new file mode 100644 index 000000000..ca60c6582 --- /dev/null +++ b/lumina_train.py @@ -0,0 +1,957 @@ +# training with captions + +# Swap blocks between CPU and GPU: +# This implementation is inspired by and based on the work of 2kpr. +# Many thanks to 2kpr for the original concept and implementation of memory-efficient offloading. +# The original idea has been adapted and extended to fit the current project's needs. + +# Key features: +# - CPU offloading during forward and backward passes +# - Use of fused optimizer and grad_hook for efficient gradient processing +# - Per-block fused optimizer instances + +import argparse +import copy +import math +import os +from multiprocessing import Value +import toml + +from tqdm import tqdm + +import torch +from library.device_utils import init_ipex, clean_memory_on_device + +init_ipex() + +from accelerate.utils import set_seed +from library import ( + deepspeed_utils, + lumina_train_util, + lumina_util, + strategy_base, + strategy_lumina, + sai_model_spec +) +from library.sd3_train_utils import FlowMatchEulerDiscreteScheduler + +import library.train_util as train_util + +from library.utils import setup_logging, add_logging_arguments + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + +import library.config_util as config_util + +# import library.sdxl_train_util as sdxl_train_util +from library.config_util import ( + ConfigSanitizer, + BlueprintGenerator, +) +from library.custom_train_functions import apply_masked_loss, add_custom_train_arguments + + +def train(args): + train_util.verify_training_args(args) + train_util.prepare_dataset_args(args, True) + # sdxl_train_util.verify_sdxl_training_args(args) + deepspeed_utils.prepare_deepspeed_args(args) + setup_logging(args, reset=True) + + # temporary: backward compatibility for deprecated options. remove in the future + if not args.skip_cache_check: + args.skip_cache_check = args.skip_latents_validity_check + + # assert ( + # not args.weighted_captions + # ), "weighted_captions is not supported currently / weighted_captionsは現在サポートされていません" + if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs: + logger.warning( + "cache_text_encoder_outputs_to_disk is enabled, so cache_text_encoder_outputs is also enabled / cache_text_encoder_outputs_to_diskが有効になっているため、cache_text_encoder_outputsも有効になります" + ) + args.cache_text_encoder_outputs = True + + if args.cpu_offload_checkpointing and not args.gradient_checkpointing: + logger.warning( + "cpu_offload_checkpointing is enabled, so gradient_checkpointing is also enabled / cpu_offload_checkpointingが有効になっているため、gradient_checkpointingも有効になります" + ) + args.gradient_checkpointing = True + + # assert ( + # args.blocks_to_swap is None or args.blocks_to_swap == 0 + # ) or not args.cpu_offload_checkpointing, "blocks_to_swap is not supported with cpu_offload_checkpointing / blocks_to_swapはcpu_offload_checkpointingと併用できません" + + cache_latents = args.cache_latents + use_dreambooth_method = args.in_json is None + + if args.seed is not None: + set_seed(args.seed) # 乱数系列を初期化する + + # prepare caching strategy: this must be set before preparing dataset. because dataset may use this strategy for initialization. + if args.cache_latents: + latents_caching_strategy = strategy_lumina.LuminaLatentsCachingStrategy( + args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check + ) + strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy) + + # データセットを準備する + if args.dataset_class is None: + blueprint_generator = BlueprintGenerator( + ConfigSanitizer(True, True, args.masked_loss, True) + ) + if args.dataset_config is not None: + logger.info(f"Load dataset config from {args.dataset_config}") + user_config = config_util.load_user_config(args.dataset_config) + ignored = ["train_data_dir", "in_json"] + if any(getattr(args, attr) is not None for attr in ignored): + logger.warning( + "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( + ", ".join(ignored) + ) + ) + else: + if use_dreambooth_method: + logger.info("Using DreamBooth method.") + user_config = { + "datasets": [ + { + "subsets": config_util.generate_dreambooth_subsets_config_by_subdirs( + args.train_data_dir, args.reg_data_dir + ) + } + ] + } + else: + logger.info("Training with captions.") + user_config = { + "datasets": [ + { + "subsets": [ + { + "image_dir": args.train_data_dir, + "metadata_file": args.in_json, + } + ] + } + ] + } + + blueprint = blueprint_generator.generate(user_config, args) + train_dataset_group, val_dataset_group = ( + config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) + ) + else: + train_dataset_group = train_util.load_arbitrary_dataset(args) + val_dataset_group = None + + current_epoch = Value("i", 0) + current_step = Value("i", 0) + ds_for_collator = ( + train_dataset_group if args.max_data_loader_n_workers == 0 else None + ) + collator = train_util.collator_class(current_epoch, current_step, ds_for_collator) + + train_dataset_group.verify_bucket_reso_steps(16) # TODO これでいいか確認 + + if args.debug_dataset: + if args.cache_text_encoder_outputs: + strategy_base.TextEncoderOutputsCachingStrategy.set_strategy( + strategy_lumina.LuminaTextEncoderOutputsCachingStrategy( + args.cache_text_encoder_outputs_to_disk, + args.text_encoder_batch_size, + args.skip_cache_check, + False, + ) + ) + strategy_base.TokenizeStrategy.set_strategy( + strategy_lumina.LuminaTokenizeStrategy(args.system_prompt) + ) + + train_dataset_group.set_current_strategies() + train_util.debug_dataset(train_dataset_group, True) + return + if len(train_dataset_group) == 0: + logger.error( + "No data found. Please verify the metadata file and train_data_dir option. / 画像がありません。メタデータおよびtrain_data_dirオプションを確認してください。" + ) + return + + if cache_latents: + assert ( + train_dataset_group.is_latent_cacheable() + ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません" + + if args.cache_text_encoder_outputs: + assert ( + train_dataset_group.is_text_encoder_output_cacheable() + ), "when caching text encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / text encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません" + + # acceleratorを準備する + logger.info("prepare accelerator") + accelerator = train_util.prepare_accelerator(args) + + # mixed precisionに対応した型を用意しておき適宜castする + weight_dtype, save_dtype = train_util.prepare_dtype(args) + + # モデルを読み込む + + # load VAE for caching latents + ae = None + if cache_latents: + ae = lumina_util.load_ae( + args.ae, weight_dtype, "cpu", args.disable_mmap_load_safetensors + ) + ae.to(accelerator.device, dtype=weight_dtype) + ae.requires_grad_(False) + ae.eval() + + train_dataset_group.new_cache_latents(ae, accelerator) + + ae.to("cpu") # if no sampling, vae can be deleted + clean_memory_on_device(accelerator.device) + + accelerator.wait_for_everyone() + + # prepare tokenize strategy + if args.gemma2_max_token_length is None: + gemma2_max_token_length = 256 + else: + gemma2_max_token_length = args.gemma2_max_token_length + + lumina_tokenize_strategy = strategy_lumina.LuminaTokenizeStrategy( + args.system_prompt, gemma2_max_token_length + ) + strategy_base.TokenizeStrategy.set_strategy(lumina_tokenize_strategy) + + # load gemma2 for caching text encoder outputs + gemma2 = lumina_util.load_gemma2( + args.gemma2, weight_dtype, "cpu", args.disable_mmap_load_safetensors + ) + gemma2.eval() + gemma2.requires_grad_(False) + + text_encoding_strategy = strategy_lumina.LuminaTextEncodingStrategy() + strategy_base.TextEncodingStrategy.set_strategy(text_encoding_strategy) + + # cache text encoder outputs + sample_prompts_te_outputs = None + if args.cache_text_encoder_outputs: + # Text Encodes are eval and no grad here + gemma2.to(accelerator.device) + + text_encoder_caching_strategy = ( + strategy_lumina.LuminaTextEncoderOutputsCachingStrategy( + args.cache_text_encoder_outputs_to_disk, + args.text_encoder_batch_size, + False, + False, + ) + ) + strategy_base.TextEncoderOutputsCachingStrategy.set_strategy( + text_encoder_caching_strategy + ) + + with accelerator.autocast(): + train_dataset_group.new_cache_text_encoder_outputs([gemma2], accelerator) + + # cache sample prompt's embeddings to free text encoder's memory + if args.sample_prompts is not None: + logger.info( + f"cache Text Encoder outputs for sample prompt: {args.sample_prompts}" + ) + + text_encoding_strategy: strategy_lumina.LuminaTextEncodingStrategy = ( + strategy_base.TextEncodingStrategy.get_strategy() + ) + + prompts = train_util.load_prompts(args.sample_prompts) + sample_prompts_te_outputs = {} # key: prompt, value: text encoder outputs + with accelerator.autocast(), torch.no_grad(): + for prompt_dict in prompts: + for i, p in enumerate([ + prompt_dict.get("prompt", ""), + prompt_dict.get("negative_prompt", ""), + ]): + if p not in sample_prompts_te_outputs: + logger.info(f"cache Text Encoder outputs for prompt: {p}") + tokens_and_masks = lumina_tokenize_strategy.tokenize(p, i == 1) # i == 1 means negative prompt + sample_prompts_te_outputs[p] = ( + text_encoding_strategy.encode_tokens( + lumina_tokenize_strategy, + [gemma2], + tokens_and_masks, + ) + ) + + accelerator.wait_for_everyone() + + # now we can delete Text Encoders to free memory + gemma2 = None + clean_memory_on_device(accelerator.device) + + # load lumina + nextdit = lumina_util.load_lumina_model( + args.pretrained_model_name_or_path, + weight_dtype, + torch.device("cpu"), + disable_mmap=args.disable_mmap_load_safetensors, + use_flash_attn=args.use_flash_attn, + ) + + if args.gradient_checkpointing: + nextdit.enable_gradient_checkpointing( + cpu_offload=args.cpu_offload_checkpointing + ) + + nextdit.requires_grad_(True) + + # block swap + + # backward compatibility + # if args.blocks_to_swap is None: + # blocks_to_swap = args.double_blocks_to_swap or 0 + # if args.single_blocks_to_swap is not None: + # blocks_to_swap += args.single_blocks_to_swap // 2 + # if blocks_to_swap > 0: + # logger.warning( + # "double_blocks_to_swap and single_blocks_to_swap are deprecated. Use blocks_to_swap instead." + # " / double_blocks_to_swapとsingle_blocks_to_swapは非推奨です。blocks_to_swapを使ってください。" + # ) + # logger.info( + # f"double_blocks_to_swap={args.double_blocks_to_swap} and single_blocks_to_swap={args.single_blocks_to_swap} are converted to blocks_to_swap={blocks_to_swap}." + # ) + # args.blocks_to_swap = blocks_to_swap + # del blocks_to_swap + + # is_swapping_blocks = args.blocks_to_swap is not None and args.blocks_to_swap > 0 + # if is_swapping_blocks: + # # Swap blocks between CPU and GPU to reduce memory usage, in forward and backward passes. + # # This idea is based on 2kpr's great work. Thank you! + # logger.info(f"enable block swap: blocks_to_swap={args.blocks_to_swap}") + # flux.enable_block_swap(args.blocks_to_swap, accelerator.device) + + if not cache_latents: + # load VAE here if not cached + ae = lumina_util.load_ae(args.ae, weight_dtype, "cpu") + ae.requires_grad_(False) + ae.eval() + ae.to(accelerator.device, dtype=weight_dtype) + + training_models = [] + params_to_optimize = [] + training_models.append(nextdit) + name_and_params = list(nextdit.named_parameters()) + # single param group for now + params_to_optimize.append( + {"params": [p for _, p in name_and_params], "lr": args.learning_rate} + ) + param_names = [[n for n, _ in name_and_params]] + + # calculate number of trainable parameters + n_params = 0 + for group in params_to_optimize: + for p in group["params"]: + n_params += p.numel() + + accelerator.print(f"number of trainable parameters: {n_params}") + + # 学習に必要なクラスを準備する + accelerator.print("prepare optimizer, data loader etc.") + + if args.blockwise_fused_optimizers: + # fused backward pass: https://pytorch.org/tutorials/intermediate/optimizer_step_in_backward_tutorial.html + # Instead of creating an optimizer for all parameters as in the tutorial, we create an optimizer for each block of parameters. + # This balances memory usage and management complexity. + + # split params into groups. currently different learning rates are not supported + grouped_params = [] + param_group = {} + for group in params_to_optimize: + named_parameters = list(nextdit.named_parameters()) + assert len(named_parameters) == len( + group["params"] + ), "number of parameters does not match" + for p, np in zip(group["params"], named_parameters): + # determine target layer and block index for each parameter + block_type = "other" # double, single or other + if np[0].startswith("double_blocks"): + block_index = int(np[0].split(".")[1]) + block_type = "double" + elif np[0].startswith("single_blocks"): + block_index = int(np[0].split(".")[1]) + block_type = "single" + else: + block_index = -1 + + param_group_key = (block_type, block_index) + if param_group_key not in param_group: + param_group[param_group_key] = [] + param_group[param_group_key].append(p) + + block_types_and_indices = [] + for param_group_key, param_group in param_group.items(): + block_types_and_indices.append(param_group_key) + grouped_params.append({"params": param_group, "lr": args.learning_rate}) + + num_params = 0 + for p in param_group: + num_params += p.numel() + accelerator.print(f"block {param_group_key}: {num_params} parameters") + + # prepare optimizers for each group + optimizers = [] + for group in grouped_params: + _, _, optimizer = train_util.get_optimizer(args, trainable_params=[group]) + optimizers.append(optimizer) + optimizer = optimizers[0] # avoid error in the following code + + logger.info( + f"using {len(optimizers)} optimizers for blockwise fused optimizers" + ) + + if train_util.is_schedulefree_optimizer(optimizers[0], args): + raise ValueError( + "Schedule-free optimizer is not supported with blockwise fused optimizers" + ) + optimizer_train_fn = lambda: None # dummy function + optimizer_eval_fn = lambda: None # dummy function + else: + _, _, optimizer = train_util.get_optimizer( + args, trainable_params=params_to_optimize + ) + optimizer_train_fn, optimizer_eval_fn = train_util.get_optimizer_train_eval_fn( + optimizer, args + ) + + # prepare dataloader + # strategies are set here because they cannot be referenced in another process. Copy them with the dataset + # some strategies can be None + train_dataset_group.set_current_strategies() + + # DataLoaderのプロセス数:0 は persistent_workers が使えないので注意 + n_workers = min( + args.max_data_loader_n_workers, os.cpu_count() + ) # cpu_count or max_data_loader_n_workers + train_dataloader = torch.utils.data.DataLoader( + train_dataset_group, + batch_size=1, + shuffle=True, + collate_fn=collator, + num_workers=n_workers, + persistent_workers=args.persistent_data_loader_workers, + ) + + # 学習ステップ数を計算する + if args.max_train_epochs is not None: + args.max_train_steps = args.max_train_epochs * math.ceil( + len(train_dataloader) + / accelerator.num_processes + / args.gradient_accumulation_steps + ) + accelerator.print( + f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}" + ) + + # データセット側にも学習ステップを送信 + train_dataset_group.set_max_train_steps(args.max_train_steps) + + # lr schedulerを用意する + if args.blockwise_fused_optimizers: + # prepare lr schedulers for each optimizer + lr_schedulers = [ + train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) + for optimizer in optimizers + ] + lr_scheduler = lr_schedulers[0] # avoid error in the following code + else: + lr_scheduler = train_util.get_scheduler_fix( + args, optimizer, accelerator.num_processes + ) + + # 実験的機能:勾配も含めたfp16/bf16学習を行う モデル全体をfp16/bf16にする + if args.full_fp16: + assert ( + args.mixed_precision == "fp16" + ), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。" + accelerator.print("enable full fp16 training.") + nextdit.to(weight_dtype) + if gemma2 is not None: + gemma2.to(weight_dtype) + elif args.full_bf16: + assert ( + args.mixed_precision == "bf16" + ), "full_bf16 requires mixed precision='bf16' / full_bf16を使う場合はmixed_precision='bf16'を指定してください。" + accelerator.print("enable full bf16 training.") + nextdit.to(weight_dtype) + if gemma2 is not None: + gemma2.to(weight_dtype) + + # if we don't cache text encoder outputs, move them to device + if not args.cache_text_encoder_outputs: + gemma2.to(accelerator.device) + + clean_memory_on_device(accelerator.device) + + is_swapping_blocks = args.blocks_to_swap is not None and args.blocks_to_swap > 0 + + if args.deepspeed: + ds_model = deepspeed_utils.prepare_deepspeed_model(args, nextdit=nextdit) + # most of ZeRO stage uses optimizer partitioning, so we have to prepare optimizer and ds_model at the same time. # pull/1139#issuecomment-1986790007 + ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + ds_model, optimizer, train_dataloader, lr_scheduler + ) + training_models = [ds_model] + + else: + # accelerator does some magic + # if we doesn't swap blocks, we can move the model to device + nextdit = accelerator.prepare( + nextdit, device_placement=[not is_swapping_blocks] + ) + if is_swapping_blocks: + accelerator.unwrap_model(nextdit).move_to_device_except_swap_blocks( + accelerator.device + ) # reduce peak memory usage + optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + optimizer, train_dataloader, lr_scheduler + ) + + # 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする + if args.full_fp16: + # During deepseed training, accelerate not handles fp16/bf16|mixed precision directly via scaler. Let deepspeed engine do. + # -> But we think it's ok to patch accelerator even if deepspeed is enabled. + train_util.patch_accelerator_for_fp16_training(accelerator) + + # resumeする + train_util.resume_from_local_or_hf_if_specified(accelerator, args) + + if args.fused_backward_pass: + # use fused optimizer for backward pass: other optimizers will be supported in the future + import library.adafactor_fused + + library.adafactor_fused.patch_adafactor_fused(optimizer) + + for param_group, param_name_group in zip(optimizer.param_groups, param_names): + for parameter, param_name in zip(param_group["params"], param_name_group): + if parameter.requires_grad: + + def create_grad_hook(p_name, p_group): + def grad_hook(tensor: torch.Tensor): + if accelerator.sync_gradients and args.max_grad_norm != 0.0: + accelerator.clip_grad_norm_(tensor, args.max_grad_norm) + optimizer.step_param(tensor, p_group) + tensor.grad = None + + return grad_hook + + parameter.register_post_accumulate_grad_hook( + create_grad_hook(param_name, param_group) + ) + + elif args.blockwise_fused_optimizers: + # prepare for additional optimizers and lr schedulers + for i in range(1, len(optimizers)): + optimizers[i] = accelerator.prepare(optimizers[i]) + lr_schedulers[i] = accelerator.prepare(lr_schedulers[i]) + + # counters are used to determine when to step the optimizer + global optimizer_hooked_count + global num_parameters_per_group + global parameter_optimizer_map + + optimizer_hooked_count = {} + num_parameters_per_group = [0] * len(optimizers) + parameter_optimizer_map = {} + + for opt_idx, optimizer in enumerate(optimizers): + for param_group in optimizer.param_groups: + for parameter in param_group["params"]: + if parameter.requires_grad: + + def grad_hook(parameter: torch.Tensor): + if accelerator.sync_gradients and args.max_grad_norm != 0.0: + accelerator.clip_grad_norm_( + parameter, args.max_grad_norm + ) + + i = parameter_optimizer_map[parameter] + optimizer_hooked_count[i] += 1 + if optimizer_hooked_count[i] == num_parameters_per_group[i]: + optimizers[i].step() + optimizers[i].zero_grad(set_to_none=True) + + parameter.register_post_accumulate_grad_hook(grad_hook) + parameter_optimizer_map[parameter] = opt_idx + num_parameters_per_group[opt_idx] += 1 + + # epoch数を計算する + num_update_steps_per_epoch = math.ceil( + len(train_dataloader) / args.gradient_accumulation_steps + ) + num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0): + args.save_every_n_epochs = ( + math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1 + ) + + # 学習する + # total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + accelerator.print("running training / 学習開始") + accelerator.print( + f" num examples / サンプル数: {train_dataset_group.num_train_images}" + ) + accelerator.print( + f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}" + ) + accelerator.print(f" num epochs / epoch数: {num_train_epochs}") + accelerator.print( + f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}" + ) + # accelerator.print( + # f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}" + # ) + accelerator.print( + f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}" + ) + accelerator.print( + f" total optimization steps / 学習ステップ数: {args.max_train_steps}" + ) + + progress_bar = tqdm( + range(args.max_train_steps), + smoothing=0, + disable=not accelerator.is_local_main_process, + desc="steps", + ) + global_step = 0 + + noise_scheduler = FlowMatchEulerDiscreteScheduler( + num_train_timesteps=1000, shift=args.discrete_flow_shift + ) + noise_scheduler_copy = copy.deepcopy(noise_scheduler) + + if accelerator.is_main_process: + init_kwargs = {} + if args.wandb_run_name: + init_kwargs["wandb"] = {"name": args.wandb_run_name} + if args.log_tracker_config is not None: + init_kwargs = toml.load(args.log_tracker_config) + accelerator.init_trackers( + "finetuning" if args.log_tracker_name is None else args.log_tracker_name, + config=train_util.get_sanitized_config_or_none(args), + init_kwargs=init_kwargs, + ) + + if is_swapping_blocks: + accelerator.unwrap_model(nextdit).prepare_block_swap_before_forward() + + # For --sample_at_first + optimizer_eval_fn() + lumina_train_util.sample_images( + accelerator, + args, + 0, + global_step, + nextdit, + ae, + gemma2, + sample_prompts_te_outputs, + ) + optimizer_train_fn() + if len(accelerator.trackers) > 0: + # log empty object to commit the sample images to wandb + accelerator.log({}, step=0) + + loss_recorder = train_util.LossRecorder() + epoch = 0 # avoid error when max_train_steps is 0 + for epoch in range(num_train_epochs): + accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}") + current_epoch.value = epoch + 1 + + for m in training_models: + m.train() + + for step, batch in enumerate(train_dataloader): + current_step.value = global_step + + if args.blockwise_fused_optimizers: + optimizer_hooked_count = { + i: 0 for i in range(len(optimizers)) + } # reset counter for each step + + with accelerator.accumulate(*training_models): + if "latents" in batch and batch["latents"] is not None: + latents = batch["latents"].to( + accelerator.device, dtype=weight_dtype + ) + else: + with torch.no_grad(): + # encode images to latents. images are [-1, 1] + latents = ae.encode(batch["images"].to(ae.dtype)).to( + accelerator.device, dtype=weight_dtype + ) + + # NaNが含まれていれば警告を表示し0に置き換える + if torch.any(torch.isnan(latents)): + accelerator.print("NaN found in latents, replacing with zeros") + latents = torch.nan_to_num(latents, 0, out=latents) + + text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None) + if text_encoder_outputs_list is not None: + text_encoder_conds = text_encoder_outputs_list + else: + # not cached or training, so get from text encoders + tokens_and_masks = batch["input_ids_list"] + with torch.no_grad(): + input_ids = [ + ids.to(accelerator.device) + for ids in batch["input_ids_list"] + ] + text_encoder_conds = text_encoding_strategy.encode_tokens( + lumina_tokenize_strategy, + [gemma2], + input_ids, + ) + if args.full_fp16: + text_encoder_conds = [ + c.to(weight_dtype) for c in text_encoder_conds + ] + + # TODO support some features for noise implemented in get_noise_noisy_latents_and_timesteps + + # Sample noise that we'll add to the latents + noise = torch.randn_like(latents) + + # get noisy model input and timesteps + noisy_model_input, timesteps, sigmas = ( + lumina_train_util.get_noisy_model_input_and_timesteps( + args, + noise_scheduler_copy, + latents, + noise, + accelerator.device, + weight_dtype, + ) + ) + # call model + gemma2_hidden_states, input_ids, gemma2_attn_mask = text_encoder_conds + + with accelerator.autocast(): + # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing) + model_pred = nextdit( + x=noisy_model_input, # image latents (B, C, H, W) + t=timesteps / 1000, # timesteps需要除以1000来匹配模型预期 + cap_feats=gemma2_hidden_states, # Gemma2的hidden states作为caption features + cap_mask=gemma2_attn_mask.to( + dtype=torch.int32 + ), # Gemma2的attention mask + ) + # apply model prediction type + model_pred, weighting = lumina_train_util.apply_model_prediction_type( + args, model_pred, noisy_model_input, sigmas + ) + + # flow matching loss + target = latents - noise + + # calculate loss + huber_c = train_util.get_huber_threshold_if_needed( + args, timesteps, noise_scheduler + ) + loss = train_util.conditional_loss( + model_pred.float(), target.float(), args.loss_type, "none", huber_c + ) + if weighting is not None: + loss = loss * weighting + if args.masked_loss or ( + "alpha_masks" in batch and batch["alpha_masks"] is not None + ): + loss = apply_masked_loss(loss, batch) + loss = loss.mean([1, 2, 3]) + + loss_weights = batch["loss_weights"] # 各sampleごとのweight + loss = loss * loss_weights + loss = loss.mean() + + # backward + accelerator.backward(loss) + + if not (args.fused_backward_pass or args.blockwise_fused_optimizers): + if accelerator.sync_gradients and args.max_grad_norm != 0.0: + params_to_clip = [] + for m in training_models: + params_to_clip.extend(m.parameters()) + accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) + + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad(set_to_none=True) + else: + # optimizer.step() and optimizer.zero_grad() are called in the optimizer hook + lr_scheduler.step() + if args.blockwise_fused_optimizers: + for i in range(1, len(optimizers)): + lr_schedulers[i].step() + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + global_step += 1 + + optimizer_eval_fn() + lumina_train_util.sample_images( + accelerator, + args, + None, + global_step, + nextdit, + ae, + gemma2, + sample_prompts_te_outputs, + ) + + # 指定ステップごとにモデルを保存 + if ( + args.save_every_n_steps is not None + and global_step % args.save_every_n_steps == 0 + ): + accelerator.wait_for_everyone() + if accelerator.is_main_process: + lumina_train_util.save_lumina_model_on_epoch_end_or_stepwise( + args, + False, + accelerator, + save_dtype, + epoch, + num_train_epochs, + global_step, + accelerator.unwrap_model(nextdit), + ) + optimizer_train_fn() + + current_loss = loss.detach().item() # 平均なのでbatch sizeは関係ないはず + if len(accelerator.trackers) > 0: + logs = {"loss": current_loss} + train_util.append_lr_to_logs( + logs, lr_scheduler, args.optimizer_type, including_unet=True + ) + + accelerator.log(logs, step=global_step) + + loss_recorder.add(epoch=epoch, step=step, loss=current_loss) + avr_loss: float = loss_recorder.moving_average + logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]} + progress_bar.set_postfix(**logs) + + if global_step >= args.max_train_steps: + break + + if len(accelerator.trackers) > 0: + logs = {"loss/epoch": loss_recorder.moving_average} + accelerator.log(logs, step=epoch + 1) + + accelerator.wait_for_everyone() + + optimizer_eval_fn() + if args.save_every_n_epochs is not None: + if accelerator.is_main_process: + lumina_train_util.save_lumina_model_on_epoch_end_or_stepwise( + args, + True, + accelerator, + save_dtype, + epoch, + num_train_epochs, + global_step, + accelerator.unwrap_model(nextdit), + ) + + lumina_train_util.sample_images( + accelerator, + args, + epoch + 1, + global_step, + nextdit, + ae, + gemma2, + sample_prompts_te_outputs, + ) + optimizer_train_fn() + + is_main_process = accelerator.is_main_process + # if is_main_process: + nextdit = accelerator.unwrap_model(nextdit) + + accelerator.end_training() + optimizer_eval_fn() + + if args.save_state or args.save_state_on_train_end: + train_util.save_state_on_train_end(args, accelerator) + + del accelerator # この後メモリを使うのでこれは消す + + if is_main_process: + lumina_train_util.save_lumina_model_on_train_end( + args, save_dtype, epoch, global_step, nextdit + ) + logger.info("model saved.") + + +def setup_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser() + + add_logging_arguments(parser) + train_util.add_sd_models_arguments(parser) # TODO split this + sai_model_spec.add_model_spec_arguments(parser) + train_util.add_dataset_arguments(parser, True, True, True) + train_util.add_training_arguments(parser, False) + train_util.add_masked_loss_arguments(parser) + deepspeed_utils.add_deepspeed_arguments(parser) + train_util.add_sd_saving_arguments(parser) + train_util.add_optimizer_arguments(parser) + config_util.add_config_arguments(parser) + add_custom_train_arguments(parser) # TODO remove this from here + train_util.add_dit_training_arguments(parser) + lumina_train_util.add_lumina_train_arguments(parser) + + parser.add_argument( + "--mem_eff_save", + action="store_true", + help="[EXPERIMENTAL] use memory efficient custom model saving method / メモリ効率の良い独自のモデル保存方法を使う", + ) + + parser.add_argument( + "--fused_optimizer_groups", + type=int, + default=None, + help="**this option is not working** will be removed in the future / このオプションは動作しません。将来削除されます", + ) + parser.add_argument( + "--blockwise_fused_optimizers", + action="store_true", + help="enable blockwise optimizers for fused backward pass and optimizer step / fused backward passとoptimizer step のためブロック単位のoptimizerを有効にする", + ) + parser.add_argument( + "--skip_latents_validity_check", + action="store_true", + help="[Deprecated] use 'skip_cache_check' instead / 代わりに 'skip_cache_check' を使用してください", + ) + parser.add_argument( + "--cpu_offload_checkpointing", + action="store_true", + help="[EXPERIMENTAL] enable offloading of tensors to CPU during checkpointing / チェックポイント時にテンソルをCPUにオフロードする", + ) + return parser + + +if __name__ == "__main__": + parser = setup_parser() + + args = parser.parse_args() + train_util.verify_command_line_training_args(args) + args = train_util.read_config_from_file(args, parser) + + train(args) diff --git a/lumina_train_network.py b/lumina_train_network.py new file mode 100644 index 000000000..b08e31432 --- /dev/null +++ b/lumina_train_network.py @@ -0,0 +1,383 @@ +import argparse +import copy +from typing import Any, Tuple + +import torch + +from library.device_utils import clean_memory_on_device, init_ipex + +init_ipex() + +from torch import Tensor +from accelerate import Accelerator + + +import train_network +from library import ( + lumina_models, + lumina_util, + lumina_train_util, + sd3_train_utils, + strategy_base, + strategy_lumina, + train_util, +) +from library.utils import setup_logging + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + + +class LuminaNetworkTrainer(train_network.NetworkTrainer): + def __init__(self): + super().__init__() + self.sample_prompts_te_outputs = None + self.is_swapping_blocks: bool = False + + def assert_extra_args(self, args, train_dataset_group, val_dataset_group): + super().assert_extra_args(args, train_dataset_group, val_dataset_group) + + if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs: + logger.warning("Enabling cache_text_encoder_outputs due to disk caching") + args.cache_text_encoder_outputs = True + + train_dataset_group.verify_bucket_reso_steps(32) + if val_dataset_group is not None: + val_dataset_group.verify_bucket_reso_steps(32) + + self.train_gemma2 = not args.network_train_unet_only + + def load_target_model(self, args, weight_dtype, accelerator): + loading_dtype = None if args.fp8_base else weight_dtype + + model = lumina_util.load_lumina_model( + args.pretrained_model_name_or_path, + loading_dtype, + torch.device("cpu"), + disable_mmap=args.disable_mmap_load_safetensors, + use_flash_attn=args.use_flash_attn, + use_sage_attn=args.use_sage_attn, + ) + + if args.fp8_base: + # check dtype of model + if model.dtype == torch.float8_e4m3fnuz or model.dtype == torch.float8_e5m2 or model.dtype == torch.float8_e5m2fnuz: + raise ValueError(f"Unsupported fp8 model dtype: {model.dtype}") + elif model.dtype == torch.float8_e4m3fn: + logger.info("Loaded fp8 Lumina 2 model") + else: + logger.info( + "Cast Lumina 2 model to fp8. This may take a while. You can reduce the time by using fp8 checkpoint." + " / Lumina 2モデルをfp8に変換しています。これには時間がかかる場合があります。fp8チェックポイントを使用することで時間を短縮できます。" + ) + model.to(torch.float8_e4m3fn) + + if args.blocks_to_swap: + logger.info(f"Lumina 2: Enabling block swap: {args.blocks_to_swap}") + model.enable_block_swap(args.blocks_to_swap, accelerator.device) + self.is_swapping_blocks = True + + gemma2 = lumina_util.load_gemma2(args.gemma2, weight_dtype, "cpu") + gemma2.eval() + ae = lumina_util.load_ae(args.ae, weight_dtype, "cpu") + + return lumina_util.MODEL_VERSION_LUMINA_V2, [gemma2], ae, model + + def get_tokenize_strategy(self, args): + return strategy_lumina.LuminaTokenizeStrategy(args.system_prompt, args.gemma2_max_token_length, args.tokenizer_cache_dir) + + def get_tokenizers(self, tokenize_strategy: strategy_lumina.LuminaTokenizeStrategy): + return [tokenize_strategy.tokenizer] + + def get_latents_caching_strategy(self, args): + return strategy_lumina.LuminaLatentsCachingStrategy(args.cache_latents_to_disk, args.vae_batch_size, False) + + def get_text_encoding_strategy(self, args): + return strategy_lumina.LuminaTextEncodingStrategy() + + def get_text_encoders_train_flags(self, args, text_encoders): + return [self.train_gemma2] + + def get_text_encoder_outputs_caching_strategy(self, args): + if args.cache_text_encoder_outputs: + # if the text encoders is trained, we need tokenization, so is_partial is True + return strategy_lumina.LuminaTextEncoderOutputsCachingStrategy( + args.cache_text_encoder_outputs_to_disk, + args.text_encoder_batch_size, + args.skip_cache_check, + is_partial=self.train_gemma2, + ) + else: + return None + + def cache_text_encoder_outputs_if_needed( + self, + args, + accelerator: Accelerator, + unet, + vae, + text_encoders, + dataset, + weight_dtype, + ): + if args.cache_text_encoder_outputs: + if not args.lowram: + # メモリ消費を減らす + logger.info("move vae and unet to cpu to save memory") + org_vae_device = vae.device + org_unet_device = unet.device + vae.to("cpu") + unet.to("cpu") + clean_memory_on_device(accelerator.device) + + # When TE is not be trained, it will not be prepared so we need to use explicit autocast + logger.info("move text encoders to gpu") + text_encoders[0].to(accelerator.device, dtype=weight_dtype) # always not fp8 + + if text_encoders[0].dtype == torch.float8_e4m3fn: + # if we load fp8 weights, the model is already fp8, so we use it as is + self.prepare_text_encoder_fp8(1, text_encoders[1], text_encoders[1].dtype, weight_dtype) + else: + # otherwise, we need to convert it to target dtype + text_encoders[0].to(weight_dtype) + + with accelerator.autocast(): + dataset.new_cache_text_encoder_outputs(text_encoders, accelerator) + + # cache sample prompts + if args.sample_prompts is not None: + logger.info(f"cache Text Encoder outputs for sample prompts: {args.sample_prompts}") + + tokenize_strategy = strategy_base.TokenizeStrategy.get_strategy() + text_encoding_strategy = strategy_base.TextEncodingStrategy.get_strategy() + + assert isinstance(tokenize_strategy, strategy_lumina.LuminaTokenizeStrategy) + assert isinstance(text_encoding_strategy, strategy_lumina.LuminaTextEncodingStrategy) + + sample_prompts = train_util.load_prompts(args.sample_prompts) + sample_prompts_te_outputs = {} # key: prompt, value: text encoder outputs + with accelerator.autocast(), torch.no_grad(): + for prompt_dict in sample_prompts: + prompts = [ + prompt_dict.get("prompt", ""), + prompt_dict.get("negative_prompt", ""), + ] + for i, prompt in enumerate(prompts): + if prompt in sample_prompts_te_outputs: + continue + + logger.info(f"cache Text Encoder outputs for prompt: {prompt}") + tokens_and_masks = tokenize_strategy.tokenize(prompt, i == 1) # i == 1 means negative prompt + sample_prompts_te_outputs[prompt] = text_encoding_strategy.encode_tokens( + tokenize_strategy, + text_encoders, + tokens_and_masks, + ) + + self.sample_prompts_te_outputs = sample_prompts_te_outputs + + accelerator.wait_for_everyone() + + # move back to cpu + if not self.is_train_text_encoder(args): + logger.info("move Gemma 2 back to cpu") + text_encoders[0].to("cpu") + clean_memory_on_device(accelerator.device) + + if not args.lowram: + logger.info("move vae and unet back to original device") + vae.to(org_vae_device) + unet.to(org_unet_device) + else: + # Text Encoderから毎回出力を取得するので、GPUに乗せておく + text_encoders[0].to(accelerator.device, dtype=weight_dtype) + + def sample_images( + self, + accelerator, + args, + epoch, + global_step, + device, + vae, + tokenizer, + text_encoder, + lumina, + ): + lumina_train_util.sample_images( + accelerator, + args, + epoch, + global_step, + lumina, + vae, + self.get_models_for_text_encoding(args, accelerator, text_encoder), + self.sample_prompts_te_outputs, + ) + + # Remaining methods maintain similar structure to flux implementation + # with Lumina-specific model calls and strategies + + def get_noise_scheduler(self, args: argparse.Namespace, device: torch.device) -> Any: + noise_scheduler = sd3_train_utils.FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=args.discrete_flow_shift) + self.noise_scheduler_copy = copy.deepcopy(noise_scheduler) + return noise_scheduler + + def encode_images_to_latents(self, args, vae, images): + return vae.encode(images) + + # not sure, they use same flux vae + def shift_scale_latents(self, args, latents): + return latents + + def get_noise_pred_and_target( + self, + args, + accelerator: Accelerator, + noise_scheduler, + latents, + batch, + text_encoder_conds: Tuple[Tensor, Tensor, Tensor], # (hidden_states, input_ids, attention_masks) + dit: lumina_models.NextDiT, + network, + weight_dtype, + train_unet, + is_train=True, + ): + assert isinstance(noise_scheduler, sd3_train_utils.FlowMatchEulerDiscreteScheduler) + noise = torch.randn_like(latents) + # get noisy model input and timesteps + noisy_model_input, timesteps, sigmas = lumina_train_util.get_noisy_model_input_and_timesteps( + args, noise_scheduler, latents, noise, accelerator.device, weight_dtype + ) + + # ensure the hidden state will require grad + if args.gradient_checkpointing: + noisy_model_input.requires_grad_(True) + for t in text_encoder_conds: + if t is not None and t.dtype.is_floating_point: + t.requires_grad_(True) + + # Unpack Gemma2 outputs + gemma2_hidden_states, input_ids, gemma2_attn_mask = text_encoder_conds + + def call_dit(img, gemma2_hidden_states, gemma2_attn_mask, timesteps): + with torch.set_grad_enabled(is_train), accelerator.autocast(): + # NextDiT forward expects (x, t, cap_feats, cap_mask) + model_pred = dit( + x=img, # image latents (B, C, H, W) + t=timesteps / 1000, # timesteps需要除以1000来匹配模型预期 + cap_feats=gemma2_hidden_states, # Gemma2的hidden states作为caption features + cap_mask=gemma2_attn_mask.to(dtype=torch.int32), # Gemma2的attention mask + ) + return model_pred + + model_pred = call_dit( + img=noisy_model_input, + gemma2_hidden_states=gemma2_hidden_states, + gemma2_attn_mask=gemma2_attn_mask, + timesteps=timesteps, + ) + + # apply model prediction type + model_pred, weighting = lumina_train_util.apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas) + + # flow matching loss + target = latents - noise + + # differential output preservation + if "custom_attributes" in batch: + diff_output_pr_indices = [] + for i, custom_attributes in enumerate(batch["custom_attributes"]): + if "diff_output_preservation" in custom_attributes and custom_attributes["diff_output_preservation"]: + diff_output_pr_indices.append(i) + + if len(diff_output_pr_indices) > 0: + network.set_multiplier(0.0) + with torch.no_grad(): + model_pred_prior = call_dit( + img=noisy_model_input[diff_output_pr_indices], + gemma2_hidden_states=gemma2_hidden_states[diff_output_pr_indices], + timesteps=timesteps[diff_output_pr_indices], + gemma2_attn_mask=(gemma2_attn_mask[diff_output_pr_indices]), + ) + network.set_multiplier(1.0) + + # model_pred_prior = lumina_util.unpack_latents( + # model_pred_prior, packed_latent_height, packed_latent_width + # ) + model_pred_prior, _ = lumina_train_util.apply_model_prediction_type( + args, + model_pred_prior, + noisy_model_input[diff_output_pr_indices], + sigmas[diff_output_pr_indices] if sigmas is not None else None, + ) + target[diff_output_pr_indices] = model_pred_prior.to(target.dtype) + + return model_pred, target, timesteps, weighting + + def post_process_loss(self, loss, args, timesteps, noise_scheduler): + return loss + + def get_sai_model_spec(self, args): + return train_util.get_sai_model_spec(None, args, False, True, False, lumina="lumina2") + + def update_metadata(self, metadata, args): + metadata["ss_weighting_scheme"] = args.weighting_scheme + metadata["ss_logit_mean"] = args.logit_mean + metadata["ss_logit_std"] = args.logit_std + metadata["ss_mode_scale"] = args.mode_scale + metadata["ss_timestep_sampling"] = args.timestep_sampling + metadata["ss_sigmoid_scale"] = args.sigmoid_scale + metadata["ss_model_prediction_type"] = args.model_prediction_type + metadata["ss_discrete_flow_shift"] = args.discrete_flow_shift + + def is_text_encoder_not_needed_for_training(self, args): + return args.cache_text_encoder_outputs and not self.is_train_text_encoder(args) + + def prepare_text_encoder_grad_ckpt_workaround(self, index, text_encoder): + text_encoder.embed_tokens.requires_grad_(True) + + def prepare_text_encoder_fp8(self, index, text_encoder, te_weight_dtype, weight_dtype): + logger.info(f"prepare Gemma2 for fp8: set to {te_weight_dtype}, set embeddings to {weight_dtype}") + text_encoder.to(te_weight_dtype) # fp8 + text_encoder.embed_tokens.to(dtype=weight_dtype) + + def prepare_unet_with_accelerator( + self, args: argparse.Namespace, accelerator: Accelerator, unet: torch.nn.Module + ) -> torch.nn.Module: + if not self.is_swapping_blocks: + return super().prepare_unet_with_accelerator(args, accelerator, unet) + + # if we doesn't swap blocks, we can move the model to device + nextdit = unet + assert isinstance(nextdit, lumina_models.NextDiT) + nextdit = accelerator.prepare(nextdit, device_placement=[not self.is_swapping_blocks]) + accelerator.unwrap_model(nextdit).move_to_device_except_swap_blocks(accelerator.device) # reduce peak memory usage + accelerator.unwrap_model(nextdit).prepare_block_swap_before_forward() + + return nextdit + + def on_validation_step_end(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype): + if self.is_swapping_blocks: + # prepare for next forward: because backward pass is not called, we need to prepare it here + accelerator.unwrap_model(unet).prepare_block_swap_before_forward() + + +def setup_parser() -> argparse.ArgumentParser: + parser = train_network.setup_parser() + train_util.add_dit_training_arguments(parser) + lumina_train_util.add_lumina_train_arguments(parser) + return parser + + +if __name__ == "__main__": + parser = setup_parser() + args = parser.parse_args() + train_util.verify_command_line_training_args(args) + args = train_util.read_config_from_file(args, parser) + + trainer = LuminaNetworkTrainer() + trainer.train(args) diff --git a/networks/convert_flux_lora.py b/networks/convert_flux_lora.py new file mode 100644 index 000000000..fe6466ebc --- /dev/null +++ b/networks/convert_flux_lora.py @@ -0,0 +1,434 @@ +# convert key mapping and data format from some LoRA format to another +""" +Original LoRA format: Based on Black Forest Labs, QKV and MLP are unified into one module +alpha is scalar for each LoRA module + +0 to 18 +lora_unet_double_blocks_0_img_attn_proj.alpha torch.Size([]) +lora_unet_double_blocks_0_img_attn_proj.lora_down.weight torch.Size([4, 3072]) +lora_unet_double_blocks_0_img_attn_proj.lora_up.weight torch.Size([3072, 4]) +lora_unet_double_blocks_0_img_attn_qkv.alpha torch.Size([]) +lora_unet_double_blocks_0_img_attn_qkv.lora_down.weight torch.Size([4, 3072]) +lora_unet_double_blocks_0_img_attn_qkv.lora_up.weight torch.Size([9216, 4]) +lora_unet_double_blocks_0_img_mlp_0.alpha torch.Size([]) +lora_unet_double_blocks_0_img_mlp_0.lora_down.weight torch.Size([4, 3072]) +lora_unet_double_blocks_0_img_mlp_0.lora_up.weight torch.Size([12288, 4]) +lora_unet_double_blocks_0_img_mlp_2.alpha torch.Size([]) +lora_unet_double_blocks_0_img_mlp_2.lora_down.weight torch.Size([4, 12288]) +lora_unet_double_blocks_0_img_mlp_2.lora_up.weight torch.Size([3072, 4]) +lora_unet_double_blocks_0_img_mod_lin.alpha torch.Size([]) +lora_unet_double_blocks_0_img_mod_lin.lora_down.weight torch.Size([4, 3072]) +lora_unet_double_blocks_0_img_mod_lin.lora_up.weight torch.Size([18432, 4]) +lora_unet_double_blocks_0_txt_attn_proj.alpha torch.Size([]) +lora_unet_double_blocks_0_txt_attn_proj.lora_down.weight torch.Size([4, 3072]) +lora_unet_double_blocks_0_txt_attn_proj.lora_up.weight torch.Size([3072, 4]) +lora_unet_double_blocks_0_txt_attn_qkv.alpha torch.Size([]) +lora_unet_double_blocks_0_txt_attn_qkv.lora_down.weight torch.Size([4, 3072]) +lora_unet_double_blocks_0_txt_attn_qkv.lora_up.weight torch.Size([9216, 4]) +lora_unet_double_blocks_0_txt_mlp_0.alpha torch.Size([]) +lora_unet_double_blocks_0_txt_mlp_0.lora_down.weight torch.Size([4, 3072]) +lora_unet_double_blocks_0_txt_mlp_0.lora_up.weight torch.Size([12288, 4]) +lora_unet_double_blocks_0_txt_mlp_2.alpha torch.Size([]) +lora_unet_double_blocks_0_txt_mlp_2.lora_down.weight torch.Size([4, 12288]) +lora_unet_double_blocks_0_txt_mlp_2.lora_up.weight torch.Size([3072, 4]) +lora_unet_double_blocks_0_txt_mod_lin.alpha torch.Size([]) +lora_unet_double_blocks_0_txt_mod_lin.lora_down.weight torch.Size([4, 3072]) +lora_unet_double_blocks_0_txt_mod_lin.lora_up.weight torch.Size([18432, 4]) + +0 to 37 +lora_unet_single_blocks_0_linear1.alpha torch.Size([]) +lora_unet_single_blocks_0_linear1.lora_down.weight torch.Size([4, 3072]) +lora_unet_single_blocks_0_linear1.lora_up.weight torch.Size([21504, 4]) +lora_unet_single_blocks_0_linear2.alpha torch.Size([]) +lora_unet_single_blocks_0_linear2.lora_down.weight torch.Size([4, 15360]) +lora_unet_single_blocks_0_linear2.lora_up.weight torch.Size([3072, 4]) +lora_unet_single_blocks_0_modulation_lin.alpha torch.Size([]) +lora_unet_single_blocks_0_modulation_lin.lora_down.weight torch.Size([4, 3072]) +lora_unet_single_blocks_0_modulation_lin.lora_up.weight torch.Size([9216, 4]) +""" +""" +ai-toolkit: Based on Diffusers, QKV and MLP are separated into 3 modules. +A is down, B is up. No alpha for each LoRA module. + +0 to 18 +transformer.transformer_blocks.0.attn.add_k_proj.lora_A.weight torch.Size([16, 3072]) +transformer.transformer_blocks.0.attn.add_k_proj.lora_B.weight torch.Size([3072, 16]) +transformer.transformer_blocks.0.attn.add_q_proj.lora_A.weight torch.Size([16, 3072]) +transformer.transformer_blocks.0.attn.add_q_proj.lora_B.weight torch.Size([3072, 16]) +transformer.transformer_blocks.0.attn.add_v_proj.lora_A.weight torch.Size([16, 3072]) +transformer.transformer_blocks.0.attn.add_v_proj.lora_B.weight torch.Size([3072, 16]) +transformer.transformer_blocks.0.attn.to_add_out.lora_A.weight torch.Size([16, 3072]) +transformer.transformer_blocks.0.attn.to_add_out.lora_B.weight torch.Size([3072, 16]) +transformer.transformer_blocks.0.attn.to_k.lora_A.weight torch.Size([16, 3072]) +transformer.transformer_blocks.0.attn.to_k.lora_B.weight torch.Size([3072, 16]) +transformer.transformer_blocks.0.attn.to_out.0.lora_A.weight torch.Size([16, 3072]) +transformer.transformer_blocks.0.attn.to_out.0.lora_B.weight torch.Size([3072, 16]) +transformer.transformer_blocks.0.attn.to_q.lora_A.weight torch.Size([16, 3072]) +transformer.transformer_blocks.0.attn.to_q.lora_B.weight torch.Size([3072, 16]) +transformer.transformer_blocks.0.attn.to_v.lora_A.weight torch.Size([16, 3072]) +transformer.transformer_blocks.0.attn.to_v.lora_B.weight torch.Size([3072, 16]) +transformer.transformer_blocks.0.ff.net.0.proj.lora_A.weight torch.Size([16, 3072]) +transformer.transformer_blocks.0.ff.net.0.proj.lora_B.weight torch.Size([12288, 16]) +transformer.transformer_blocks.0.ff.net.2.lora_A.weight torch.Size([16, 12288]) +transformer.transformer_blocks.0.ff.net.2.lora_B.weight torch.Size([3072, 16]) +transformer.transformer_blocks.0.ff_context.net.0.proj.lora_A.weight torch.Size([16, 3072]) +transformer.transformer_blocks.0.ff_context.net.0.proj.lora_B.weight torch.Size([12288, 16]) +transformer.transformer_blocks.0.ff_context.net.2.lora_A.weight torch.Size([16, 12288]) +transformer.transformer_blocks.0.ff_context.net.2.lora_B.weight torch.Size([3072, 16]) +transformer.transformer_blocks.0.norm1.linear.lora_A.weight torch.Size([16, 3072]) +transformer.transformer_blocks.0.norm1.linear.lora_B.weight torch.Size([18432, 16]) +transformer.transformer_blocks.0.norm1_context.linear.lora_A.weight torch.Size([16, 3072]) +transformer.transformer_blocks.0.norm1_context.linear.lora_B.weight torch.Size([18432, 16]) + +0 to 37 +transformer.single_transformer_blocks.0.attn.to_k.lora_A.weight torch.Size([16, 3072]) +transformer.single_transformer_blocks.0.attn.to_k.lora_B.weight torch.Size([3072, 16]) +transformer.single_transformer_blocks.0.attn.to_q.lora_A.weight torch.Size([16, 3072]) +transformer.single_transformer_blocks.0.attn.to_q.lora_B.weight torch.Size([3072, 16]) +transformer.single_transformer_blocks.0.attn.to_v.lora_A.weight torch.Size([16, 3072]) +transformer.single_transformer_blocks.0.attn.to_v.lora_B.weight torch.Size([3072, 16]) +transformer.single_transformer_blocks.0.norm.linear.lora_A.weight torch.Size([16, 3072]) +transformer.single_transformer_blocks.0.norm.linear.lora_B.weight torch.Size([9216, 16]) +transformer.single_transformer_blocks.0.proj_mlp.lora_A.weight torch.Size([16, 3072]) +transformer.single_transformer_blocks.0.proj_mlp.lora_B.weight torch.Size([12288, 16]) +transformer.single_transformer_blocks.0.proj_out.lora_A.weight torch.Size([16, 15360]) +transformer.single_transformer_blocks.0.proj_out.lora_B.weight torch.Size([3072, 16]) +""" +""" +xlabs: Unknown format. +0 to 18 +double_blocks.0.processor.proj_lora1.down.weight torch.Size([16, 3072]) +double_blocks.0.processor.proj_lora1.up.weight torch.Size([3072, 16]) +double_blocks.0.processor.proj_lora2.down.weight torch.Size([16, 3072]) +double_blocks.0.processor.proj_lora2.up.weight torch.Size([3072, 16]) +double_blocks.0.processor.qkv_lora1.down.weight torch.Size([16, 3072]) +double_blocks.0.processor.qkv_lora1.up.weight torch.Size([9216, 16]) +double_blocks.0.processor.qkv_lora2.down.weight torch.Size([16, 3072]) +double_blocks.0.processor.qkv_lora2.up.weight torch.Size([9216, 16]) +""" + + +import argparse +from safetensors.torch import save_file +from safetensors import safe_open +import torch + + +from library.utils import setup_logging + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + + +def convert_to_sd_scripts(sds_sd, ait_sd, sds_key, ait_key): + ait_down_key = ait_key + ".lora_A.weight" + if ait_down_key not in ait_sd: + return + ait_up_key = ait_key + ".lora_B.weight" + + down_weight = ait_sd.pop(ait_down_key) + sds_sd[sds_key + ".lora_down.weight"] = down_weight + sds_sd[sds_key + ".lora_up.weight"] = ait_sd.pop(ait_up_key) + rank = down_weight.shape[0] + sds_sd[sds_key + ".alpha"] = torch.scalar_tensor(rank, dtype=down_weight.dtype, device=down_weight.device) + + +def convert_to_sd_scripts_cat(sds_sd, ait_sd, sds_key, ait_keys): + ait_down_keys = [k + ".lora_A.weight" for k in ait_keys] + if ait_down_keys[0] not in ait_sd: + return + ait_up_keys = [k + ".lora_B.weight" for k in ait_keys] + + down_weights = [ait_sd.pop(k) for k in ait_down_keys] + up_weights = [ait_sd.pop(k) for k in ait_up_keys] + + # lora_down is concatenated along dim=0, so rank is multiplied by the number of splits + rank = down_weights[0].shape[0] + num_splits = len(ait_keys) + sds_sd[sds_key + ".lora_down.weight"] = torch.cat(down_weights, dim=0) + + merged_up_weights = torch.zeros( + (sum(w.shape[0] for w in up_weights), rank * num_splits), + dtype=up_weights[0].dtype, + device=up_weights[0].device, + ) + + i = 0 + for j, up_weight in enumerate(up_weights): + merged_up_weights[i : i + up_weight.shape[0], j * rank : (j + 1) * rank] = up_weight + i += up_weight.shape[0] + + sds_sd[sds_key + ".lora_up.weight"] = merged_up_weights + + # set alpha to new_rank + new_rank = rank * num_splits + sds_sd[sds_key + ".alpha"] = torch.scalar_tensor(new_rank, dtype=down_weights[0].dtype, device=down_weights[0].device) + + +def convert_ai_toolkit_to_sd_scripts(ait_sd): + sds_sd = {} + for i in range(19): + convert_to_sd_scripts( + sds_sd, ait_sd, f"lora_unet_double_blocks_{i}_img_attn_proj", f"transformer.transformer_blocks.{i}.attn.to_out.0" + ) + convert_to_sd_scripts_cat( + sds_sd, + ait_sd, + f"lora_unet_double_blocks_{i}_img_attn_qkv", + [ + f"transformer.transformer_blocks.{i}.attn.to_q", + f"transformer.transformer_blocks.{i}.attn.to_k", + f"transformer.transformer_blocks.{i}.attn.to_v", + ], + ) + convert_to_sd_scripts( + sds_sd, ait_sd, f"lora_unet_double_blocks_{i}_img_mlp_0", f"transformer.transformer_blocks.{i}.ff.net.0.proj" + ) + convert_to_sd_scripts( + sds_sd, ait_sd, f"lora_unet_double_blocks_{i}_img_mlp_2", f"transformer.transformer_blocks.{i}.ff.net.2" + ) + convert_to_sd_scripts( + sds_sd, ait_sd, f"lora_unet_double_blocks_{i}_img_mod_lin", f"transformer.transformer_blocks.{i}.norm1.linear" + ) + convert_to_sd_scripts( + sds_sd, ait_sd, f"lora_unet_double_blocks_{i}_txt_attn_proj", f"transformer.transformer_blocks.{i}.attn.to_add_out" + ) + convert_to_sd_scripts_cat( + sds_sd, + ait_sd, + f"lora_unet_double_blocks_{i}_txt_attn_qkv", + [ + f"transformer.transformer_blocks.{i}.attn.add_q_proj", + f"transformer.transformer_blocks.{i}.attn.add_k_proj", + f"transformer.transformer_blocks.{i}.attn.add_v_proj", + ], + ) + convert_to_sd_scripts( + sds_sd, ait_sd, f"lora_unet_double_blocks_{i}_txt_mlp_0", f"transformer.transformer_blocks.{i}.ff_context.net.0.proj" + ) + convert_to_sd_scripts( + sds_sd, ait_sd, f"lora_unet_double_blocks_{i}_txt_mlp_2", f"transformer.transformer_blocks.{i}.ff_context.net.2" + ) + convert_to_sd_scripts( + sds_sd, ait_sd, f"lora_unet_double_blocks_{i}_txt_mod_lin", f"transformer.transformer_blocks.{i}.norm1_context.linear" + ) + + for i in range(38): + convert_to_sd_scripts_cat( + sds_sd, + ait_sd, + f"lora_unet_single_blocks_{i}_linear1", + [ + f"transformer.single_transformer_blocks.{i}.attn.to_q", + f"transformer.single_transformer_blocks.{i}.attn.to_k", + f"transformer.single_transformer_blocks.{i}.attn.to_v", + f"transformer.single_transformer_blocks.{i}.proj_mlp", + ], + ) + convert_to_sd_scripts( + sds_sd, ait_sd, f"lora_unet_single_blocks_{i}_linear2", f"transformer.single_transformer_blocks.{i}.proj_out" + ) + convert_to_sd_scripts( + sds_sd, ait_sd, f"lora_unet_single_blocks_{i}_modulation_lin", f"transformer.single_transformer_blocks.{i}.norm.linear" + ) + + if len(ait_sd) > 0: + logger.warning(f"Unsuppored keys for sd-scripts: {ait_sd.keys()}") + return sds_sd + + +def convert_to_ai_toolkit(sds_sd, ait_sd, sds_key, ait_key): + if sds_key + ".lora_down.weight" not in sds_sd: + return + down_weight = sds_sd.pop(sds_key + ".lora_down.weight") + + # scale weight by alpha and dim + rank = down_weight.shape[0] + alpha = sds_sd.pop(sds_key + ".alpha").item() # alpha is scalar + scale = alpha / rank # LoRA is scaled by 'alpha / rank' in forward pass, so we need to scale it back here + # print(f"rank: {rank}, alpha: {alpha}, scale: {scale}") + + # calculate scale_down and scale_up to keep the same value. if scale is 4, scale_down is 2 and scale_up is 2 + scale_down = scale + scale_up = 1.0 + while scale_down * 2 < scale_up: + scale_down *= 2 + scale_up /= 2 + # print(f"scale: {scale}, scale_down: {scale_down}, scale_up: {scale_up}") + + ait_sd[ait_key + ".lora_A.weight"] = down_weight * scale_down + ait_sd[ait_key + ".lora_B.weight"] = sds_sd.pop(sds_key + ".lora_up.weight") * scale_up + + +def convert_to_ai_toolkit_cat(sds_sd, ait_sd, sds_key, ait_keys, dims=None): + if sds_key + ".lora_down.weight" not in sds_sd: + return + down_weight = sds_sd.pop(sds_key + ".lora_down.weight") + up_weight = sds_sd.pop(sds_key + ".lora_up.weight") + sd_lora_rank = down_weight.shape[0] + + # scale weight by alpha and dim + alpha = sds_sd.pop(sds_key + ".alpha") + scale = alpha / sd_lora_rank + + # calculate scale_down and scale_up + scale_down = scale + scale_up = 1.0 + while scale_down * 2 < scale_up: + scale_down *= 2 + scale_up /= 2 + + down_weight = down_weight * scale_down + up_weight = up_weight * scale_up + + # calculate dims if not provided + num_splits = len(ait_keys) + if dims is None: + dims = [up_weight.shape[0] // num_splits] * num_splits + else: + assert sum(dims) == up_weight.shape[0] + + # check upweight is sparse or not + is_sparse = False + if sd_lora_rank % num_splits == 0: + ait_rank = sd_lora_rank // num_splits + is_sparse = True + i = 0 + for j in range(len(dims)): + for k in range(len(dims)): + if j == k: + continue + is_sparse = is_sparse and torch.all(up_weight[i : i + dims[j], k * ait_rank : (k + 1) * ait_rank] == 0) + i += dims[j] + if is_sparse: + logger.info(f"weight is sparse: {sds_key}") + + # make ai-toolkit weight + ait_down_keys = [k + ".lora_A.weight" for k in ait_keys] + ait_up_keys = [k + ".lora_B.weight" for k in ait_keys] + if not is_sparse: + # down_weight is copied to each split + ait_sd.update({k: down_weight for k in ait_down_keys}) + + # up_weight is split to each split + ait_sd.update({k: v for k, v in zip(ait_up_keys, torch.split(up_weight, dims, dim=0))}) + else: + # down_weight is chunked to each split + ait_sd.update({k: v for k, v in zip(ait_down_keys, torch.chunk(down_weight, num_splits, dim=0))}) + + # up_weight is sparse: only non-zero values are copied to each split + i = 0 + for j in range(len(dims)): + ait_sd[ait_up_keys[j]] = up_weight[i : i + dims[j], j * ait_rank : (j + 1) * ait_rank].contiguous() + i += dims[j] + + +def convert_sd_scripts_to_ai_toolkit(sds_sd): + ait_sd = {} + for i in range(19): + convert_to_ai_toolkit( + sds_sd, ait_sd, f"lora_unet_double_blocks_{i}_img_attn_proj", f"transformer.transformer_blocks.{i}.attn.to_out.0" + ) + convert_to_ai_toolkit_cat( + sds_sd, + ait_sd, + f"lora_unet_double_blocks_{i}_img_attn_qkv", + [ + f"transformer.transformer_blocks.{i}.attn.to_q", + f"transformer.transformer_blocks.{i}.attn.to_k", + f"transformer.transformer_blocks.{i}.attn.to_v", + ], + ) + convert_to_ai_toolkit( + sds_sd, ait_sd, f"lora_unet_double_blocks_{i}_img_mlp_0", f"transformer.transformer_blocks.{i}.ff.net.0.proj" + ) + convert_to_ai_toolkit( + sds_sd, ait_sd, f"lora_unet_double_blocks_{i}_img_mlp_2", f"transformer.transformer_blocks.{i}.ff.net.2" + ) + convert_to_ai_toolkit( + sds_sd, ait_sd, f"lora_unet_double_blocks_{i}_img_mod_lin", f"transformer.transformer_blocks.{i}.norm1.linear" + ) + convert_to_ai_toolkit( + sds_sd, ait_sd, f"lora_unet_double_blocks_{i}_txt_attn_proj", f"transformer.transformer_blocks.{i}.attn.to_add_out" + ) + convert_to_ai_toolkit_cat( + sds_sd, + ait_sd, + f"lora_unet_double_blocks_{i}_txt_attn_qkv", + [ + f"transformer.transformer_blocks.{i}.attn.add_q_proj", + f"transformer.transformer_blocks.{i}.attn.add_k_proj", + f"transformer.transformer_blocks.{i}.attn.add_v_proj", + ], + ) + convert_to_ai_toolkit( + sds_sd, ait_sd, f"lora_unet_double_blocks_{i}_txt_mlp_0", f"transformer.transformer_blocks.{i}.ff_context.net.0.proj" + ) + convert_to_ai_toolkit( + sds_sd, ait_sd, f"lora_unet_double_blocks_{i}_txt_mlp_2", f"transformer.transformer_blocks.{i}.ff_context.net.2" + ) + convert_to_ai_toolkit( + sds_sd, ait_sd, f"lora_unet_double_blocks_{i}_txt_mod_lin", f"transformer.transformer_blocks.{i}.norm1_context.linear" + ) + + for i in range(38): + convert_to_ai_toolkit_cat( + sds_sd, + ait_sd, + f"lora_unet_single_blocks_{i}_linear1", + [ + f"transformer.single_transformer_blocks.{i}.attn.to_q", + f"transformer.single_transformer_blocks.{i}.attn.to_k", + f"transformer.single_transformer_blocks.{i}.attn.to_v", + f"transformer.single_transformer_blocks.{i}.proj_mlp", + ], + dims=[3072, 3072, 3072, 12288], + ) + convert_to_ai_toolkit( + sds_sd, ait_sd, f"lora_unet_single_blocks_{i}_linear2", f"transformer.single_transformer_blocks.{i}.proj_out" + ) + convert_to_ai_toolkit( + sds_sd, ait_sd, f"lora_unet_single_blocks_{i}_modulation_lin", f"transformer.single_transformer_blocks.{i}.norm.linear" + ) + + if len(sds_sd) > 0: + logger.warning(f"Unsuppored keys for ai-toolkit: {sds_sd.keys()}") + return ait_sd + + +def main(args): + # load source safetensors + logger.info(f"Loading source file {args.src_path}") + state_dict = {} + with safe_open(args.src_path, framework="pt") as f: + metadata = f.metadata() + for k in f.keys(): + state_dict[k] = f.get_tensor(k) + + logger.info(f"Converting {args.src} to {args.dst} format") + if args.src == "ai-toolkit" and args.dst == "sd-scripts": + state_dict = convert_ai_toolkit_to_sd_scripts(state_dict) + elif args.src == "sd-scripts" and args.dst == "ai-toolkit": + state_dict = convert_sd_scripts_to_ai_toolkit(state_dict) + + # eliminate 'shared tensors' + for k in list(state_dict.keys()): + state_dict[k] = state_dict[k].detach().clone() + else: + raise NotImplementedError(f"Conversion from {args.src} to {args.dst} is not supported") + + # save destination safetensors + logger.info(f"Saving destination file {args.dst_path}") + save_file(state_dict, args.dst_path, metadata=metadata) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Convert LoRA format") + parser.add_argument("--src", type=str, default="ai-toolkit", help="source format, ai-toolkit or sd-scripts") + parser.add_argument("--dst", type=str, default="sd-scripts", help="destination format, ai-toolkit or sd-scripts") + parser.add_argument("--src_path", type=str, default=None, help="source path") + parser.add_argument("--dst_path", type=str, default=None, help="destination path") + args = parser.parse_args() + main(args) diff --git a/networks/convert_hunyuan_image_lora_to_comfy.py b/networks/convert_hunyuan_image_lora_to_comfy.py new file mode 100644 index 000000000..df12897df --- /dev/null +++ b/networks/convert_hunyuan_image_lora_to_comfy.py @@ -0,0 +1,88 @@ +import argparse +from safetensors.torch import save_file +from safetensors import safe_open +import torch + + +from library import train_util +from library.utils import setup_logging + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + + +def main(args): + # load source safetensors + logger.info(f"Loading source file {args.src_path}") + state_dict = {} + with safe_open(args.src_path, framework="pt") as f: + metadata = f.metadata() + for k in f.keys(): + state_dict[k] = f.get_tensor(k) + + logger.info(f"Converting...") + + # Key mapping tables: (sd-scripts format, ComfyUI format) + double_blocks_mappings = [ + ("img_mlp_fc1", "img_mlp_0"), + ("img_mlp_fc2", "img_mlp_2"), + ("img_mod_linear", "img_mod_lin"), + ("txt_mlp_fc1", "txt_mlp_0"), + ("txt_mlp_fc2", "txt_mlp_2"), + ("txt_mod_linear", "txt_mod_lin"), + ] + + single_blocks_mappings = [ + ("modulation_linear", "modulation_lin"), + ] + + keys = list(state_dict.keys()) + count = 0 + + for k in keys: + new_k = k + + if "double_blocks" in k: + mappings = double_blocks_mappings + elif "single_blocks" in k: + mappings = single_blocks_mappings + else: + continue + + # Apply mappings based on conversion direction + for src_key, dst_key in mappings: + if args.reverse: + # ComfyUI to sd-scripts: swap src and dst + new_k = new_k.replace(dst_key, src_key) + else: + # sd-scripts to ComfyUI: use as-is + new_k = new_k.replace(src_key, dst_key) + + if new_k != k: + state_dict[new_k] = state_dict.pop(k) + count += 1 + # print(f"Renamed {k} to {new_k}") + + logger.info(f"Converted {count} keys") + + # Calculate hash + if metadata is not None: + logger.info(f"Calculating hashes and creating metadata...") + model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata) + metadata["sshs_model_hash"] = model_hash + metadata["sshs_legacy_hash"] = legacy_hash + + # save destination safetensors + logger.info(f"Saving destination file {args.dst_path}") + save_file(state_dict, args.dst_path, metadata=metadata) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Convert LoRA format") + parser.add_argument("src_path", type=str, default=None, help="source path, sd-scripts format") + parser.add_argument("dst_path", type=str, default=None, help="destination path, ComfyUI format") + parser.add_argument("--reverse", action="store_true", help="reverse conversion direction") + args = parser.parse_args() + main(args) diff --git a/networks/flux_extract_lora.py b/networks/flux_extract_lora.py new file mode 100644 index 000000000..f1ae8f965 --- /dev/null +++ b/networks/flux_extract_lora.py @@ -0,0 +1,220 @@ +# extract approximating LoRA by svd from two FLUX models +# The code is based on https://github.com/cloneofsimo/lora/blob/develop/lora_diffusion/cli_svd.py +# Thanks to cloneofsimo! + +import argparse +import json +import os +import time +import torch +from safetensors.torch import load_file, save_file +from safetensors import safe_open +from tqdm import tqdm +from library import flux_utils, sai_model_spec +from library.safetensors_utils import MemoryEfficientSafeOpen +from library.utils import setup_logging +from networks import lora_flux + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + +# CLAMP_QUANTILE = 0.99 +# MIN_DIFF = 1e-1 + + +def save_to_file(file_name, state_dict, metadata, dtype): + if dtype is not None: + for key in list(state_dict.keys()): + if type(state_dict[key]) == torch.Tensor: + state_dict[key] = state_dict[key].to(dtype) + + save_file(state_dict, file_name, metadata=metadata) + + +def svd( + model_org=None, + model_tuned=None, + save_to=None, + dim=4, + device=None, + save_precision=None, + clamp_quantile=0.99, + min_diff=0.01, + no_metadata=False, + mem_eff_safe_open=False, +): + def str_to_dtype(p): + if p == "float": + return torch.float + if p == "fp16": + return torch.float16 + if p == "bf16": + return torch.bfloat16 + return None + + calc_dtype = torch.float + save_dtype = str_to_dtype(save_precision) + store_device = "cpu" + + # open models + lora_weights = {} + if not mem_eff_safe_open: + # use original safetensors.safe_open + open_fn = lambda fn: safe_open(fn, framework="pt") + else: + logger.info("Using memory efficient safe_open") + open_fn = lambda fn: MemoryEfficientSafeOpen(fn) + + with open_fn(model_org) as f_org: + # filter keys + keys = [] + for key in f_org.keys(): + if not ("single_block" in key or "double_block" in key): + continue + if ".bias" in key: + continue + if "norm" in key: + continue + keys.append(key) + + with open_fn(model_tuned) as f_tuned: + for key in tqdm(keys): + # get tensors and calculate difference + value_o = f_org.get_tensor(key) + value_t = f_tuned.get_tensor(key) + mat = value_t.to(calc_dtype) - value_o.to(calc_dtype) + del value_o, value_t + + # extract LoRA weights + if device: + mat = mat.to(device) + out_dim, in_dim = mat.size()[0:2] + rank = min(dim, in_dim, out_dim) # LoRA rank cannot exceed the original dim + + mat = mat.squeeze() + + U, S, Vh = torch.linalg.svd(mat) + + U = U[:, :rank] + S = S[:rank] + U = U @ torch.diag(S) + + Vh = Vh[:rank, :] + + dist = torch.cat([U.flatten(), Vh.flatten()]) + hi_val = torch.quantile(dist, clamp_quantile) + low_val = -hi_val + + U = U.clamp(low_val, hi_val) + Vh = Vh.clamp(low_val, hi_val) + + U = U.to(store_device, dtype=save_dtype).contiguous() + Vh = Vh.to(store_device, dtype=save_dtype).contiguous() + + # print(f"key: {key}, U: {U.size()}, Vh: {Vh.size()}") + lora_weights[key] = (U, Vh) + del mat, U, S, Vh + + # make state dict for LoRA + lora_sd = {} + for key, (up_weight, down_weight) in lora_weights.items(): + lora_name = key.replace(".weight", "").replace(".", "_") + lora_name = lora_flux.LoRANetwork.LORA_PREFIX_FLUX + "_" + lora_name + lora_sd[lora_name + ".lora_up.weight"] = up_weight + lora_sd[lora_name + ".lora_down.weight"] = down_weight + lora_sd[lora_name + ".alpha"] = torch.tensor(down_weight.size()[0]) # same as rank + + # minimum metadata + net_kwargs = {} + metadata = { + "ss_v2": str(False), + "ss_base_model_version": flux_utils.MODEL_VERSION_FLUX_V1, + "ss_network_module": "networks.lora_flux", + "ss_network_dim": str(dim), + "ss_network_alpha": str(float(dim)), + "ss_network_args": json.dumps(net_kwargs), + } + + if not no_metadata: + title = os.path.splitext(os.path.basename(save_to))[0] + sai_metadata = sai_model_spec.build_metadata( + lora_sd, False, False, False, True, False, time.time(), title, model_config={"flux": "dev"} + ) + metadata.update(sai_metadata) + + save_to_file(save_to, lora_sd, metadata, save_dtype) + + logger.info(f"LoRA weights saved to {save_to}") + + +def setup_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser() + parser.add_argument( + "--save_precision", + type=str, + default=None, + choices=[None, "float", "fp16", "bf16"], + help="precision in saving, same to merging if omitted / 保存時に精度を変更して保存する、省略時はfloat", + ) + parser.add_argument( + "--model_org", + type=str, + default=None, + required=True, + help="Original model: safetensors file / 元モデル、safetensors", + ) + parser.add_argument( + "--model_tuned", + type=str, + default=None, + required=True, + help="Tuned model, LoRA is difference of `original to tuned`: safetensors file / 派生モデル(生成されるLoRAは元→派生の差分になります)、ckptまたはsafetensors", + ) + parser.add_argument( + "--mem_eff_safe_open", + action="store_true", + help="use memory efficient safe_open. This is an experimental feature, use only when memory is not enough." + " / メモリ効率の良いsafe_openを使用する。実装は実験的なものなので、メモリが足りない場合のみ使用してください。", + ) + parser.add_argument( + "--save_to", + type=str, + default=None, + required=True, + help="destination file name: safetensors file / 保存先のファイル名、safetensors", + ) + parser.add_argument( + "--dim", type=int, default=4, help="dimension (rank) of LoRA (default 4) / LoRAの次元数(rank)(デフォルト4)" + ) + parser.add_argument( + "--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う" + ) + parser.add_argument( + "--clamp_quantile", + type=float, + default=0.99, + help="Quantile clamping value, float, (0-1). Default = 0.99 / 値をクランプするための分位点、float、(0-1)。デフォルトは0.99", + ) + # parser.add_argument( + # "--min_diff", + # type=float, + # default=0.01, + # help="Minimum difference between finetuned model and base to consider them different enough to extract, float, (0-1). Default = 0.01 /" + # + "LoRAを抽出するために元モデルと派生モデルの差分の最小値、float、(0-1)。デフォルトは0.01", + # ) + parser.add_argument( + "--no_metadata", + action="store_true", + help="do not save sai modelspec metadata (minimum ss_metadata for LoRA is saved) / " + + "sai modelspecのメタデータを保存しない(LoRAの最低限のss_metadataは保存される)", + ) + return parser + + +if __name__ == "__main__": + parser = setup_parser() + + args = parser.parse_args() + svd(**vars(args)) diff --git a/networks/flux_merge_lora.py b/networks/flux_merge_lora.py new file mode 100644 index 000000000..45ff67497 --- /dev/null +++ b/networks/flux_merge_lora.py @@ -0,0 +1,784 @@ +import argparse +import math +import os +import time +from typing import Any, Dict, Union + +import torch +from safetensors import safe_open +from safetensors.torch import load_file, save_file +from tqdm import tqdm + +from library.utils import setup_logging, str_to_dtype +from library.safetensors_utils import MemoryEfficientSafeOpen, mem_eff_save_file + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + +import lora_flux as lora_flux +from library import sai_model_spec, train_util + + +def load_state_dict(file_name, dtype): + if os.path.splitext(file_name)[1] == ".safetensors": + sd = load_file(file_name) + metadata = train_util.load_metadata_from_safetensors(file_name) + else: + sd = torch.load(file_name, map_location="cpu") + metadata = {} + + for key in list(sd.keys()): + if type(sd[key]) == torch.Tensor: + sd[key] = sd[key].to(dtype) + + return sd, metadata + + +def save_to_file(file_name, state_dict: Dict[str, Union[Any, torch.Tensor]], dtype, metadata, mem_eff_save=False): + if dtype is not None: + logger.info(f"converting to {dtype}...") + for key in tqdm(list(state_dict.keys())): + if type(state_dict[key]) == torch.Tensor and state_dict[key].dtype.is_floating_point: + state_dict[key] = state_dict[key].to(dtype) + + logger.info(f"saving to: {file_name}") + if mem_eff_save: + mem_eff_save_file(state_dict, file_name, metadata=metadata) + else: + save_file(state_dict, file_name, metadata=metadata) + + +def merge_to_flux_model( + loading_device, + working_device, + flux_path: str, + clip_l_path: str, + t5xxl_path: str, + models, + ratios, + merge_dtype, + save_dtype, + mem_eff_load_save=False, +): + # create module map without loading state_dict + lora_name_to_module_key = {} + if flux_path is not None: + logger.info(f"loading keys from FLUX.1 model: {flux_path}") + with safe_open(flux_path, framework="pt", device=loading_device) as flux_file: + keys = list(flux_file.keys()) + for key in keys: + if key.endswith(".weight"): + module_name = ".".join(key.split(".")[:-1]) + lora_name = lora_flux.LoRANetwork.LORA_PREFIX_FLUX + "_" + module_name.replace(".", "_") + lora_name_to_module_key[lora_name] = key + + lora_name_to_clip_l_key = {} + if clip_l_path is not None: + logger.info(f"loading keys from clip_l model: {clip_l_path}") + with safe_open(clip_l_path, framework="pt", device=loading_device) as clip_l_file: + keys = list(clip_l_file.keys()) + for key in keys: + if key.endswith(".weight"): + module_name = ".".join(key.split(".")[:-1]) + lora_name = lora_flux.LoRANetwork.LORA_PREFIX_TEXT_ENCODER_CLIP + "_" + module_name.replace(".", "_") + lora_name_to_clip_l_key[lora_name] = key + + lora_name_to_t5xxl_key = {} + if t5xxl_path is not None: + logger.info(f"loading keys from t5xxl model: {t5xxl_path}") + with safe_open(t5xxl_path, framework="pt", device=loading_device) as t5xxl_file: + keys = list(t5xxl_file.keys()) + for key in keys: + if key.endswith(".weight"): + module_name = ".".join(key.split(".")[:-1]) + lora_name = lora_flux.LoRANetwork.LORA_PREFIX_TEXT_ENCODER_T5 + "_" + module_name.replace(".", "_") + lora_name_to_t5xxl_key[lora_name] = key + + flux_state_dict = {} + clip_l_state_dict = {} + t5xxl_state_dict = {} + if mem_eff_load_save: + if flux_path is not None: + with MemoryEfficientSafeOpen(flux_path) as flux_file: + for key in tqdm(flux_file.keys()): + flux_state_dict[key] = flux_file.get_tensor(key).to(loading_device) # dtype is not changed + + if clip_l_path is not None: + with MemoryEfficientSafeOpen(clip_l_path) as clip_l_file: + for key in tqdm(clip_l_file.keys()): + clip_l_state_dict[key] = clip_l_file.get_tensor(key).to(loading_device) + + if t5xxl_path is not None: + with MemoryEfficientSafeOpen(t5xxl_path) as t5xxl_file: + for key in tqdm(t5xxl_file.keys()): + t5xxl_state_dict[key] = t5xxl_file.get_tensor(key).to(loading_device) + else: + if flux_path is not None: + flux_state_dict = load_file(flux_path, device=loading_device) + if clip_l_path is not None: + clip_l_state_dict = load_file(clip_l_path, device=loading_device) + if t5xxl_path is not None: + t5xxl_state_dict = load_file(t5xxl_path, device=loading_device) + + for model, ratio in zip(models, ratios): + logger.info(f"loading: {model}") + lora_sd, _ = load_state_dict(model, merge_dtype) # loading on CPU + + logger.info(f"merging...") + for key in tqdm(list(lora_sd.keys())): + if "lora_down" in key: + lora_name = key[: key.rfind(".lora_down")] + up_key = key.replace("lora_down", "lora_up") + alpha_key = key[: key.index("lora_down")] + "alpha" + + if lora_name in lora_name_to_module_key: + module_weight_key = lora_name_to_module_key[lora_name] + state_dict = flux_state_dict + elif lora_name in lora_name_to_clip_l_key: + module_weight_key = lora_name_to_clip_l_key[lora_name] + state_dict = clip_l_state_dict + elif lora_name in lora_name_to_t5xxl_key: + module_weight_key = lora_name_to_t5xxl_key[lora_name] + state_dict = t5xxl_state_dict + else: + logger.warning( + f"no module found for LoRA weight: {key}. Skipping..." + f"LoRAの重みに対応するモジュールが見つかりませんでした。スキップします。" + ) + continue + + down_weight = lora_sd.pop(key) + up_weight = lora_sd.pop(up_key) + + dim = down_weight.size()[0] + alpha = lora_sd.pop(alpha_key, dim) + scale = alpha / dim + + # W <- W + U * D + weight = state_dict[module_weight_key] + + weight = weight.to(working_device, merge_dtype) + up_weight = up_weight.to(working_device, merge_dtype) + down_weight = down_weight.to(working_device, merge_dtype) + + # logger.info(module_name, down_weight.size(), up_weight.size()) + if len(weight.size()) == 2: + # linear + weight = weight + ratio * (up_weight @ down_weight) * scale + elif down_weight.size()[2:4] == (1, 1): + # conv2d 1x1 + weight = ( + weight + + ratio + * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3) + * scale + ) + else: + # conv2d 3x3 + conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3) + # logger.info(conved.size(), weight.size(), module.stride, module.padding) + weight = weight + ratio * conved * scale + + state_dict[module_weight_key] = weight.to(loading_device, save_dtype) + del up_weight + del down_weight + del weight + + if len(lora_sd) > 0: + logger.warning(f"Unused keys in LoRA model: {list(lora_sd.keys())}") + + return flux_state_dict, clip_l_state_dict, t5xxl_state_dict + + +def merge_to_flux_model_diffusers( + loading_device, working_device, flux_model, models, ratios, merge_dtype, save_dtype, mem_eff_load_save=False +): + logger.info(f"loading keys from FLUX.1 model: {flux_model}") + if mem_eff_load_save: + flux_state_dict = {} + with MemoryEfficientSafeOpen(flux_model) as flux_file: + for key in tqdm(flux_file.keys()): + flux_state_dict[key] = flux_file.get_tensor(key).to(loading_device) # dtype is not changed + else: + flux_state_dict = load_file(flux_model, device=loading_device) + + def create_key_map(n_double_layers, n_single_layers): + key_map = {} + for index in range(n_double_layers): + prefix_from = f"transformer_blocks.{index}" + prefix_to = f"double_blocks.{index}" + + for end in ("weight", "bias"): + k = f"{prefix_from}.attn." + qkv_img = f"{prefix_to}.img_attn.qkv.{end}" + qkv_txt = f"{prefix_to}.txt_attn.qkv.{end}" + + key_map[f"{k}to_q.{end}"] = qkv_img + key_map[f"{k}to_k.{end}"] = qkv_img + key_map[f"{k}to_v.{end}"] = qkv_img + key_map[f"{k}add_q_proj.{end}"] = qkv_txt + key_map[f"{k}add_k_proj.{end}"] = qkv_txt + key_map[f"{k}add_v_proj.{end}"] = qkv_txt + + block_map = { + "attn.to_out.0.weight": "img_attn.proj.weight", + "attn.to_out.0.bias": "img_attn.proj.bias", + "norm1.linear.weight": "img_mod.lin.weight", + "norm1.linear.bias": "img_mod.lin.bias", + "norm1_context.linear.weight": "txt_mod.lin.weight", + "norm1_context.linear.bias": "txt_mod.lin.bias", + "attn.to_add_out.weight": "txt_attn.proj.weight", + "attn.to_add_out.bias": "txt_attn.proj.bias", + "ff.net.0.proj.weight": "img_mlp.0.weight", + "ff.net.0.proj.bias": "img_mlp.0.bias", + "ff.net.2.weight": "img_mlp.2.weight", + "ff.net.2.bias": "img_mlp.2.bias", + "ff_context.net.0.proj.weight": "txt_mlp.0.weight", + "ff_context.net.0.proj.bias": "txt_mlp.0.bias", + "ff_context.net.2.weight": "txt_mlp.2.weight", + "ff_context.net.2.bias": "txt_mlp.2.bias", + "attn.norm_q.weight": "img_attn.norm.query_norm.scale", + "attn.norm_k.weight": "img_attn.norm.key_norm.scale", + "attn.norm_added_q.weight": "txt_attn.norm.query_norm.scale", + "attn.norm_added_k.weight": "txt_attn.norm.key_norm.scale", + } + + for k, v in block_map.items(): + key_map[f"{prefix_from}.{k}"] = f"{prefix_to}.{v}" + + for index in range(n_single_layers): + prefix_from = f"single_transformer_blocks.{index}" + prefix_to = f"single_blocks.{index}" + + for end in ("weight", "bias"): + k = f"{prefix_from}.attn." + qkv = f"{prefix_to}.linear1.{end}" + key_map[f"{k}to_q.{end}"] = qkv + key_map[f"{k}to_k.{end}"] = qkv + key_map[f"{k}to_v.{end}"] = qkv + key_map[f"{prefix_from}.proj_mlp.{end}"] = qkv + + block_map = { + "norm.linear.weight": "modulation.lin.weight", + "norm.linear.bias": "modulation.lin.bias", + "proj_out.weight": "linear2.weight", + "proj_out.bias": "linear2.bias", + "attn.norm_q.weight": "norm.query_norm.scale", + "attn.norm_k.weight": "norm.key_norm.scale", + } + + for k, v in block_map.items(): + key_map[f"{prefix_from}.{k}"] = f"{prefix_to}.{v}" + + # add as-is keys + values = list([(v if isinstance(v, str) else v[0]) for v in set(key_map.values())]) + values.sort() + key_map.update({v: v for v in values}) + + return key_map + + key_map = create_key_map(18, 38) # 18 double layers, 38 single layers + + def find_matching_key(flux_dict, lora_key): + lora_key = lora_key.replace("diffusion_model.", "") + lora_key = lora_key.replace("transformer.", "") + lora_key = lora_key.replace("lora_A", "lora_down").replace("lora_B", "lora_up") + lora_key = lora_key.replace("single_transformer_blocks", "single_blocks") + lora_key = lora_key.replace("transformer_blocks", "double_blocks") + + double_block_map = { + "attn.to_out.0": "img_attn.proj", + "norm1.linear": "img_mod.lin", + "norm1_context.linear": "txt_mod.lin", + "attn.to_add_out": "txt_attn.proj", + "ff.net.0.proj": "img_mlp.0", + "ff.net.2": "img_mlp.2", + "ff_context.net.0.proj": "txt_mlp.0", + "ff_context.net.2": "txt_mlp.2", + "attn.norm_q": "img_attn.norm.query_norm", + "attn.norm_k": "img_attn.norm.key_norm", + "attn.norm_added_q": "txt_attn.norm.query_norm", + "attn.norm_added_k": "txt_attn.norm.key_norm", + "attn.to_q": "img_attn.qkv", + "attn.to_k": "img_attn.qkv", + "attn.to_v": "img_attn.qkv", + "attn.add_q_proj": "txt_attn.qkv", + "attn.add_k_proj": "txt_attn.qkv", + "attn.add_v_proj": "txt_attn.qkv", + } + single_block_map = { + "norm.linear": "modulation.lin", + "proj_out": "linear2", + "attn.norm_q": "norm.query_norm", + "attn.norm_k": "norm.key_norm", + "attn.to_q": "linear1", + "attn.to_k": "linear1", + "attn.to_v": "linear1", + "proj_mlp": "linear1", + } + + # same key exists in both single_block_map and double_block_map, so we must care about single/double + # print("lora_key before double_block_map", lora_key) + for old, new in double_block_map.items(): + if "double" in lora_key: + lora_key = lora_key.replace(old, new) + # print("lora_key before single_block_map", lora_key) + for old, new in single_block_map.items(): + if "single" in lora_key: + lora_key = lora_key.replace(old, new) + # print("lora_key after mapping", lora_key) + + if lora_key in key_map: + flux_key = key_map[lora_key] + logger.info(f"Found matching key: {flux_key}") + return flux_key + + # If not found in key_map, try partial matching + potential_key = lora_key + ".weight" + logger.info(f"Searching for key: {potential_key}") + matches = [k for k in flux_dict.keys() if potential_key in k] + if matches: + logger.info(f"Found matching key: {matches[0]}") + return matches[0] + return None + + merged_keys = set() + for model, ratio in zip(models, ratios): + logger.info(f"loading: {model}") + lora_sd, _ = load_state_dict(model, merge_dtype) + + logger.info("merging...") + for key in lora_sd.keys(): + if "lora_down" in key or "lora_A" in key: + lora_name = key[: key.rfind(".lora_down" if "lora_down" in key else ".lora_A")] + up_key = key.replace("lora_down", "lora_up").replace("lora_A", "lora_B") + alpha_key = key[: key.index("lora_down" if "lora_down" in key else "lora_A")] + "alpha" + + logger.info(f"Processing LoRA key: {lora_name}") + flux_key = find_matching_key(flux_state_dict, lora_name) + + if flux_key is None: + logger.warning(f"no module found for LoRA weight: {key}") + continue + + logger.info(f"Merging LoRA key {lora_name} into Flux key {flux_key}") + + down_weight = lora_sd[key] + up_weight = lora_sd[up_key] + + dim = down_weight.size()[0] + alpha = lora_sd.get(alpha_key, dim) + scale = alpha / dim + + weight = flux_state_dict[flux_key] + + weight = weight.to(working_device, merge_dtype) + up_weight = up_weight.to(working_device, merge_dtype) + down_weight = down_weight.to(working_device, merge_dtype) + + # print(up_weight.size(), down_weight.size(), weight.size()) + + if lora_name.startswith("transformer."): + if "qkv" in flux_key or "linear1" in flux_key: # combined qkv or qkv+mlp + update = ratio * (up_weight @ down_weight) * scale + # print(update.shape) + + if "img_attn" in flux_key or "txt_attn" in flux_key: + q, k, v = torch.chunk(weight, 3, dim=0) + if "to_q" in lora_name or "add_q_proj" in lora_name: + q += update.reshape(q.shape) + elif "to_k" in lora_name or "add_k_proj" in lora_name: + k += update.reshape(k.shape) + elif "to_v" in lora_name or "add_v_proj" in lora_name: + v += update.reshape(v.shape) + weight = torch.cat([q, k, v], dim=0) + elif "linear1" in flux_key: + q, k, v = torch.chunk(weight[: int(update.shape[-1] * 3)], 3, dim=0) + mlp = weight[int(update.shape[-1] * 3) :] + # print(q.shape, k.shape, v.shape, mlp.shape) + if "to_q" in lora_name: + q += update.reshape(q.shape) + elif "to_k" in lora_name: + k += update.reshape(k.shape) + elif "to_v" in lora_name: + v += update.reshape(v.shape) + elif "proj_mlp" in lora_name: + mlp += update.reshape(mlp.shape) + weight = torch.cat([q, k, v, mlp], dim=0) + else: + if len(weight.size()) == 2: + weight = weight + ratio * (up_weight @ down_weight) * scale + elif down_weight.size()[2:4] == (1, 1): + weight = ( + weight + + ratio + * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3) + * scale + ) + else: + conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3) + weight = weight + ratio * conved * scale + else: + if len(weight.size()) == 2: + weight = weight + ratio * (up_weight @ down_weight) * scale + elif down_weight.size()[2:4] == (1, 1): + weight = ( + weight + + ratio + * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3) + * scale + ) + else: + conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3) + weight = weight + ratio * conved * scale + + flux_state_dict[flux_key] = weight.to(loading_device, save_dtype) + merged_keys.add(flux_key) + del up_weight + del down_weight + del weight + + logger.info(f"Merged keys: {sorted(list(merged_keys))}") + return flux_state_dict + + +def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False): + base_alphas = {} # alpha for merged model + base_dims = {} + + merged_sd = {} + base_model = None + for model, ratio in zip(models, ratios): + logger.info(f"loading: {model}") + lora_sd, lora_metadata = load_state_dict(model, merge_dtype) + + if lora_metadata is not None: + if base_model is None: + base_model = lora_metadata.get(train_util.SS_METADATA_KEY_BASE_MODEL_VERSION, None) + + # get alpha and dim + alphas = {} # alpha for current model + dims = {} # dims for current model + for key in lora_sd.keys(): + if "alpha" in key: + lora_module_name = key[: key.rfind(".alpha")] + alpha = float(lora_sd[key].detach().numpy()) + alphas[lora_module_name] = alpha + if lora_module_name not in base_alphas: + base_alphas[lora_module_name] = alpha + elif "lora_down" in key: + lora_module_name = key[: key.rfind(".lora_down")] + dim = lora_sd[key].size()[0] + dims[lora_module_name] = dim + if lora_module_name not in base_dims: + base_dims[lora_module_name] = dim + + for lora_module_name in dims.keys(): + if lora_module_name not in alphas: + alpha = dims[lora_module_name] + alphas[lora_module_name] = alpha + if lora_module_name not in base_alphas: + base_alphas[lora_module_name] = alpha + + logger.info(f"dim: {list(set(dims.values()))}, alpha: {list(set(alphas.values()))}") + + # merge + logger.info("merging...") + for key in tqdm(lora_sd.keys()): + if "alpha" in key: + continue + + if "lora_up" in key and concat: + concat_dim = 1 + elif "lora_down" in key and concat: + concat_dim = 0 + else: + concat_dim = None + + lora_module_name = key[: key.rfind(".lora_")] + + base_alpha = base_alphas[lora_module_name] + alpha = alphas[lora_module_name] + + scale = math.sqrt(alpha / base_alpha) * ratio + scale = abs(scale) if "lora_up" in key else scale # マイナスの重みに対応する。 + + if key in merged_sd: + assert ( + merged_sd[key].size() == lora_sd[key].size() or concat_dim is not None + ), "weights shape mismatch, different dims? / 重みのサイズが合いません。dimが異なる可能性があります。" + if concat_dim is not None: + merged_sd[key] = torch.cat([merged_sd[key], lora_sd[key] * scale], dim=concat_dim) + else: + merged_sd[key] = merged_sd[key] + lora_sd[key] * scale + else: + merged_sd[key] = lora_sd[key] * scale + + # set alpha to sd + for lora_module_name, alpha in base_alphas.items(): + key = lora_module_name + ".alpha" + merged_sd[key] = torch.tensor(alpha) + if shuffle: + key_down = lora_module_name + ".lora_down.weight" + key_up = lora_module_name + ".lora_up.weight" + dim = merged_sd[key_down].shape[0] + perm = torch.randperm(dim) + merged_sd[key_down] = merged_sd[key_down][perm] + merged_sd[key_up] = merged_sd[key_up][:, perm] + + logger.info("merged model") + logger.info(f"dim: {list(set(base_dims.values()))}, alpha: {list(set(base_alphas.values()))}") + + # check all dims are same + dims_list = list(set(base_dims.values())) + alphas_list = list(set(base_alphas.values())) + all_same_dims = True + all_same_alphas = True + for dims in dims_list: + if dims != dims_list[0]: + all_same_dims = False + break + for alphas in alphas_list: + if alphas != alphas_list[0]: + all_same_alphas = False + break + + # build minimum metadata + dims = f"{dims_list[0]}" if all_same_dims else "Dynamic" + alphas = f"{alphas_list[0]}" if all_same_alphas else "Dynamic" + metadata = train_util.build_minimum_network_metadata(str(False), base_model, "networks.lora", dims, alphas, None) + + return merged_sd, metadata + + +def merge(args): + if args.models is None: + args.models = [] + if args.ratios is None: + args.ratios = [] + + assert len(args.models) == len( + args.ratios + ), "number of models must be equal to number of ratios / モデルの数と重みの数は合わせてください" + + merge_dtype = str_to_dtype(args.precision) + save_dtype = str_to_dtype(args.save_precision) + if save_dtype is None: + save_dtype = merge_dtype + + assert ( + args.save_to or args.clip_l_save_to or args.t5xxl_save_to + ), "save_to or clip_l_save_to or t5xxl_save_to must be specified / save_toまたはclip_l_save_toまたはt5xxl_save_toを指定してください" + dest_dir = os.path.dirname(args.save_to or args.clip_l_save_to or args.t5xxl_save_to) + if not os.path.exists(dest_dir): + logger.info(f"creating directory: {dest_dir}") + os.makedirs(dest_dir) + + if args.flux_model is not None or args.clip_l is not None or args.t5xxl is not None: + if not args.diffusers: + assert (args.clip_l is None and args.clip_l_save_to is None) or ( + args.clip_l is not None and args.clip_l_save_to is not None + ), "clip_l_save_to must be specified if clip_l is specified / clip_lが指定されている場合はclip_l_save_toも指定してください" + assert (args.t5xxl is None and args.t5xxl_save_to is None) or ( + args.t5xxl is not None and args.t5xxl_save_to is not None + ), "t5xxl_save_to must be specified if t5xxl is specified / t5xxlが指定されている場合はt5xxl_save_toも指定してください" + flux_state_dict, clip_l_state_dict, t5xxl_state_dict = merge_to_flux_model( + args.loading_device, + args.working_device, + args.flux_model, + args.clip_l, + args.t5xxl, + args.models, + args.ratios, + merge_dtype, + save_dtype, + args.mem_eff_load_save, + ) + else: + assert ( + args.clip_l is None and args.t5xxl is None + ), "clip_l and t5xxl are not supported with --diffusers / clip_l、t5xxlはDiffusersではサポートされていません" + flux_state_dict = merge_to_flux_model_diffusers( + args.loading_device, + args.working_device, + args.flux_model, + args.models, + args.ratios, + merge_dtype, + save_dtype, + args.mem_eff_load_save, + ) + clip_l_state_dict = None + t5xxl_state_dict = None + + if args.no_metadata or (flux_state_dict is None or len(flux_state_dict) == 0): + sai_metadata = None + else: + merged_from = sai_model_spec.build_merged_from([args.flux_model] + args.models) + title = os.path.splitext(os.path.basename(args.save_to))[0] + sai_metadata = sai_model_spec.build_metadata( + None, + False, + False, + False, + False, + False, + time.time(), + title=title, + merged_from=merged_from, + model_config={"flux": "dev"}, + ) + + if flux_state_dict is not None and len(flux_state_dict) > 0: + logger.info(f"saving FLUX model to: {args.save_to}") + save_to_file(args.save_to, flux_state_dict, save_dtype, sai_metadata, args.mem_eff_load_save) + + if clip_l_state_dict is not None and len(clip_l_state_dict) > 0: + logger.info(f"saving clip_l model to: {args.clip_l_save_to}") + save_to_file(args.clip_l_save_to, clip_l_state_dict, save_dtype, None, args.mem_eff_load_save) + + if t5xxl_state_dict is not None and len(t5xxl_state_dict) > 0: + logger.info(f"saving t5xxl model to: {args.t5xxl_save_to}") + save_to_file(args.t5xxl_save_to, t5xxl_state_dict, save_dtype, None, args.mem_eff_load_save) + + else: + flux_state_dict, metadata = merge_lora_models(args.models, args.ratios, merge_dtype, args.concat, args.shuffle) + + logger.info("calculating hashes and creating metadata...") + + model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(flux_state_dict, metadata) + metadata["sshs_model_hash"] = model_hash + metadata["sshs_legacy_hash"] = legacy_hash + + if not args.no_metadata: + merged_from = sai_model_spec.build_merged_from(args.models) + title = os.path.splitext(os.path.basename(args.save_to))[0] + sai_metadata = sai_model_spec.build_metadata( + flux_state_dict, + False, + False, + False, + True, + False, + time.time(), + title=title, + merged_from=merged_from, + model_config={"flux": "dev"}, + ) + metadata.update(sai_metadata) + + logger.info(f"saving model to: {args.save_to}") + save_to_file(args.save_to, flux_state_dict, save_dtype, metadata) + + +def setup_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser() + parser.add_argument( + "--save_precision", + type=str, + default=None, + help="precision in saving, same to merging if omitted. supported types: " + "float32, fp16, bf16, fp8 (same as fp8_e4m3fn), fp8_e4m3fn, fp8_e4m3fnuz, fp8_e5m2, fp8_e5m2fnuz" + " / 保存時に精度を変更して保存する、省略時はマージ時の精度と同じ", + ) + parser.add_argument( + "--precision", + type=str, + default="float", + help="precision in merging (float is recommended) / マージの計算時の精度(floatを推奨)", + ) + parser.add_argument( + "--flux_model", + type=str, + default=None, + help="FLUX.1 model to load, merge LoRA models if omitted / 読み込むモデル、指定しない場合はLoRAモデルをマージする", + ) + parser.add_argument( + "--clip_l", + type=str, + default=None, + help="path to clip_l (*.sft or *.safetensors), should be float16 / clip_lのパス(*.sftまたは*.safetensors)", + ) + parser.add_argument( + "--t5xxl", + type=str, + default=None, + help="path to t5xxl (*.sft or *.safetensors), should be float16 / t5xxlのパス(*.sftまたは*.safetensors)", + ) + parser.add_argument( + "--mem_eff_load_save", + action="store_true", + help="use custom memory efficient load and save functions for FLUX.1 model" + " / カスタムのメモリ効率の良い読み込みと保存関数をFLUX.1モデルに使用する", + ) + parser.add_argument( + "--loading_device", + type=str, + default="cpu", + help="device to load FLUX.1 model. LoRA models are loaded on CPU / FLUX.1モデルを読み込むデバイス。LoRAモデルはCPUで読み込まれます", + ) + parser.add_argument( + "--working_device", + type=str, + default="cpu", + help="device to work (merge). Merging LoRA models are done on CPU." + + " / 作業(マージ)するデバイス。LoRAモデルのマージはCPUで行われます。", + ) + parser.add_argument( + "--save_to", + type=str, + default=None, + help="destination file name: safetensors file / 保存先のファイル名、safetensorsファイル", + ) + parser.add_argument( + "--clip_l_save_to", + type=str, + default=None, + help="destination file name for clip_l: safetensors file / clip_lの保存先のファイル名、safetensorsファイル", + ) + parser.add_argument( + "--t5xxl_save_to", + type=str, + default=None, + help="destination file name for t5xxl: safetensors file / t5xxlの保存先のファイル名、safetensorsファイル", + ) + parser.add_argument( + "--models", + type=str, + nargs="*", + help="LoRA models to merge: safetensors file / マージするLoRAモデル、safetensorsファイル", + ) + parser.add_argument("--ratios", type=float, nargs="*", help="ratios for each model / それぞれのLoRAモデルの比率") + parser.add_argument( + "--no_metadata", + action="store_true", + help="do not save sai modelspec metadata (minimum ss_metadata for LoRA is saved) / " + + "sai modelspecのメタデータを保存しない(LoRAの最低限のss_metadataは保存される)", + ) + parser.add_argument( + "--concat", + action="store_true", + help="concat lora instead of merge (The dim(rank) of the output LoRA is the sum of the input dims) / " + + "マージの代わりに結合する(LoRAのdim(rank)は入力dimの合計になる)", + ) + parser.add_argument( + "--shuffle", + action="store_true", + help="shuffle lora weight./ " + "LoRAの重みをシャッフルする", + ) + parser.add_argument( + "--diffusers", + action="store_true", + help="merge Diffusers (?) LoRA models / Diffusers (?) LoRAモデルをマージする", + ) + + return parser + + +if __name__ == "__main__": + parser = setup_parser() + + args = parser.parse_args() + merge(args) diff --git a/networks/lora_flux.py b/networks/lora_flux.py new file mode 100644 index 000000000..d74d01728 --- /dev/null +++ b/networks/lora_flux.py @@ -0,0 +1,1448 @@ +# temporary minimum implementation of LoRA +# FLUX doesn't have Conv2d, so we ignore it +# TODO commonize with the original implementation + +# LoRA network module +# reference: +# https://github.com/microsoft/LoRA/blob/main/loralib/layers.py +# https://github.com/cloneofsimo/lora/blob/master/lora_diffusion/lora.py + +import math +import os +from contextlib import contextmanager +from typing import Dict, List, Optional, Tuple, Type, Union +from diffusers import AutoencoderKL +from transformers import CLIPTextModel +import numpy as np +import torch +from torch import Tensor +import re +from library.utils import setup_logging +from library.sdxl_original_unet import SdxlUNet2DConditionModel + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + + +NUM_DOUBLE_BLOCKS = 19 +NUM_SINGLE_BLOCKS = 38 + + +class LoRAModule(torch.nn.Module): + """ + replaces forward method of the original Linear, instead of replacing the original Linear module. + """ + + def __init__( + self, + lora_name, + org_module: torch.nn.Module, + multiplier=1.0, + lora_dim=4, + alpha=1, + dropout=None, + rank_dropout=None, + module_dropout=None, + split_dims: Optional[List[int]] = None, + ggpo_beta: Optional[float] = None, + ggpo_sigma: Optional[float] = None, + ): + """ + if alpha == 0 or None, alpha is rank (no scaling). + + split_dims is used to mimic the split qkv of FLUX as same as Diffusers + """ + super().__init__() + self.lora_name = lora_name + + if org_module.__class__.__name__ == "Conv2d": + in_dim = org_module.in_channels + out_dim = org_module.out_channels + else: + in_dim = org_module.in_features + out_dim = org_module.out_features + + self.lora_dim = lora_dim + self.split_dims = split_dims + + if split_dims is None: + if org_module.__class__.__name__ == "Conv2d": + kernel_size = org_module.kernel_size + stride = org_module.stride + padding = org_module.padding + self.lora_down = torch.nn.Conv2d(in_dim, self.lora_dim, kernel_size, stride, padding, bias=False) + self.lora_up = torch.nn.Conv2d(self.lora_dim, out_dim, (1, 1), (1, 1), bias=False) + else: + self.lora_down = torch.nn.Linear(in_dim, self.lora_dim, bias=False) + self.lora_up = torch.nn.Linear(self.lora_dim, out_dim, bias=False) + + torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5)) + torch.nn.init.zeros_(self.lora_up.weight) + else: + # conv2d not supported + assert sum(split_dims) == out_dim, "sum of split_dims must be equal to out_dim" + assert org_module.__class__.__name__ == "Linear", "split_dims is only supported for Linear" + # print(f"split_dims: {split_dims}") + self.lora_down = torch.nn.ModuleList( + [torch.nn.Linear(in_dim, self.lora_dim, bias=False) for _ in range(len(split_dims))] + ) + self.lora_up = torch.nn.ModuleList([torch.nn.Linear(self.lora_dim, split_dim, bias=False) for split_dim in split_dims]) + for lora_down in self.lora_down: + torch.nn.init.kaiming_uniform_(lora_down.weight, a=math.sqrt(5)) + for lora_up in self.lora_up: + torch.nn.init.zeros_(lora_up.weight) + + if type(alpha) == torch.Tensor: + alpha = alpha.detach().float().numpy() # without casting, bf16 causes error + alpha = self.lora_dim if alpha is None or alpha == 0 else alpha + self.scale = alpha / self.lora_dim + self.register_buffer("alpha", torch.tensor(alpha)) # 定数として扱える + + # same as microsoft's + self.multiplier = multiplier + self.org_module = org_module # remove in applying + self.dropout = dropout + self.rank_dropout = rank_dropout + self.module_dropout = module_dropout + + self.ggpo_sigma = ggpo_sigma + self.ggpo_beta = ggpo_beta + + if self.ggpo_beta is not None and self.ggpo_sigma is not None: + self.combined_weight_norms = None + self.grad_norms = None + self.perturbation_norm_factor = 1.0 / math.sqrt(org_module.weight.shape[0]) + self.initialize_norm_cache(org_module.weight) + self.org_module_shape: tuple[int] = org_module.weight.shape + + def apply_to(self): + self.org_forward = self.org_module.forward + self.org_module.forward = self.forward + + del self.org_module + + def forward(self, x): + org_forwarded = self.org_forward(x) + + # module dropout + if self.module_dropout is not None and self.training: + if torch.rand(1) < self.module_dropout: + return org_forwarded + + if self.split_dims is None: + lx = self.lora_down(x) + + # normal dropout + if self.dropout is not None and self.training: + lx = torch.nn.functional.dropout(lx, p=self.dropout) + + # rank dropout + if self.rank_dropout is not None and self.training: + mask = torch.rand((lx.size(0), self.lora_dim), device=lx.device) > self.rank_dropout + if len(lx.size()) == 3: + mask = mask.unsqueeze(1) # for Text Encoder + elif len(lx.size()) == 4: + mask = mask.unsqueeze(-1).unsqueeze(-1) # for Conv2d + lx = lx * mask + + # scaling for rank dropout: treat as if the rank is changed + # maskから計算することも考えられるが、augmentation的な効果を期待してrank_dropoutを用いる + scale = self.scale * (1.0 / (1.0 - self.rank_dropout)) # redundant for readability + else: + scale = self.scale + + lx = self.lora_up(lx) + + # LoRA Gradient-Guided Perturbation Optimization + if ( + self.training + and self.ggpo_sigma is not None + and self.ggpo_beta is not None + and self.combined_weight_norms is not None + and self.grad_norms is not None + ): + with torch.no_grad(): + perturbation_scale = (self.ggpo_sigma * torch.sqrt(self.combined_weight_norms**2)) + ( + self.ggpo_beta * (self.grad_norms**2) + ) + perturbation_scale_factor = (perturbation_scale * self.perturbation_norm_factor).to(self.device) + perturbation = torch.randn(self.org_module_shape, dtype=self.dtype, device=self.device) + perturbation.mul_(perturbation_scale_factor) + perturbation_output = x @ perturbation.T # Result: (batch × n) + return org_forwarded + (self.multiplier * scale * lx) + perturbation_output + else: + return org_forwarded + lx * self.multiplier * scale + else: + lxs = [lora_down(x) for lora_down in self.lora_down] + + # normal dropout + if self.dropout is not None and self.training: + lxs = [torch.nn.functional.dropout(lx, p=self.dropout) for lx in lxs] + + # rank dropout + if self.rank_dropout is not None and self.training: + masks = [torch.rand((lx.size(0), self.lora_dim), device=lx.device) > self.rank_dropout for lx in lxs] + for i in range(len(lxs)): + if len(lx.size()) == 3: + masks[i] = masks[i].unsqueeze(1) + elif len(lx.size()) == 4: + masks[i] = masks[i].unsqueeze(-1).unsqueeze(-1) + lxs[i] = lxs[i] * masks[i] + + # scaling for rank dropout: treat as if the rank is changed + scale = self.scale * (1.0 / (1.0 - self.rank_dropout)) # redundant for readability + else: + scale = self.scale + + lxs = [lora_up(lx) for lora_up, lx in zip(self.lora_up, lxs)] + + return org_forwarded + torch.cat(lxs, dim=-1) * self.multiplier * scale + + @torch.no_grad() + def initialize_norm_cache(self, org_module_weight: Tensor): + # Choose a reasonable sample size + n_rows = org_module_weight.shape[0] + sample_size = min(1000, n_rows) # Cap at 1000 samples or use all if smaller + + # Sample random indices across all rows + indices = torch.randperm(n_rows)[:sample_size] + + # Convert to a supported data type first, then index + # Use float32 for indexing operations + weights_float32 = org_module_weight.to(dtype=torch.float32) + sampled_weights = weights_float32[indices].to(device=self.device) + + # Calculate sampled norms + sampled_norms = torch.norm(sampled_weights, dim=1, keepdim=True) + + # Store the mean norm as our estimate + self.org_weight_norm_estimate = sampled_norms.mean() + + # Optional: store standard deviation for confidence intervals + self.org_weight_norm_std = sampled_norms.std() + + # Free memory + del sampled_weights, weights_float32 + + @torch.no_grad() + def validate_norm_approximation(self, org_module_weight: Tensor, verbose=True): + # Calculate the true norm (this will be slow but it's just for validation) + true_norms = [] + chunk_size = 1024 # Process in chunks to avoid OOM + + for i in range(0, org_module_weight.shape[0], chunk_size): + end_idx = min(i + chunk_size, org_module_weight.shape[0]) + chunk = org_module_weight[i:end_idx].to(device=self.device, dtype=self.dtype) + chunk_norms = torch.norm(chunk, dim=1, keepdim=True) + true_norms.append(chunk_norms.cpu()) + del chunk + + true_norms = torch.cat(true_norms, dim=0) + true_mean_norm = true_norms.mean().item() + + # Compare with our estimate + estimated_norm = self.org_weight_norm_estimate.item() + + # Calculate error metrics + absolute_error = abs(true_mean_norm - estimated_norm) + relative_error = absolute_error / true_mean_norm * 100 # as percentage + + if verbose: + logger.info(f"True mean norm: {true_mean_norm:.6f}") + logger.info(f"Estimated norm: {estimated_norm:.6f}") + logger.info(f"Absolute error: {absolute_error:.6f}") + logger.info(f"Relative error: {relative_error:.2f}%") + + return { + "true_mean_norm": true_mean_norm, + "estimated_norm": estimated_norm, + "absolute_error": absolute_error, + "relative_error": relative_error, + } + + @torch.no_grad() + def update_norms(self): + # Not running GGPO so not currently running update norms + if self.ggpo_beta is None or self.ggpo_sigma is None: + return + + # only update norms when we are training + if self.training is False: + return + + module_weights = self.lora_up.weight @ self.lora_down.weight + module_weights.mul(self.scale) + + self.weight_norms = torch.norm(module_weights, dim=1, keepdim=True) + self.combined_weight_norms = torch.sqrt( + (self.org_weight_norm_estimate**2) + torch.sum(module_weights**2, dim=1, keepdim=True) + ) + + @torch.no_grad() + def update_grad_norms(self): + if self.training is False: + print(f"skipping update_grad_norms for {self.lora_name}") + return + + lora_down_grad = None + lora_up_grad = None + + for name, param in self.named_parameters(): + if name == "lora_down.weight": + lora_down_grad = param.grad + elif name == "lora_up.weight": + lora_up_grad = param.grad + + # Calculate gradient norms if we have both gradients + if lora_down_grad is not None and lora_up_grad is not None: + with torch.autocast(self.device.type): + approx_grad = self.scale * ((self.lora_up.weight @ lora_down_grad) + (lora_up_grad @ self.lora_down.weight)) + self.grad_norms = torch.norm(approx_grad, dim=1, keepdim=True) + + @property + def device(self): + return next(self.parameters()).device + + @property + def dtype(self): + return next(self.parameters()).dtype + + +class LoRAInfModule(LoRAModule): + def __init__( + self, + lora_name, + org_module: torch.nn.Module, + multiplier=1.0, + lora_dim=4, + alpha=1, + **kwargs, + ): + # no dropout for inference + super().__init__(lora_name, org_module, multiplier, lora_dim, alpha) + + self.org_module_ref = [org_module] # 後から参照できるように + self.enabled = True + self.network: LoRANetwork = None + + def set_network(self, network): + self.network = network + + # freezeしてマージする + def merge_to(self, sd, dtype, device): + # extract weight from org_module + org_sd = self.org_module.state_dict() + weight = org_sd["weight"] + org_dtype = weight.dtype + org_device = weight.device + weight = weight.to(torch.float) # calc in float + + if dtype is None: + dtype = org_dtype + if device is None: + device = org_device + + if self.split_dims is None: + # get up/down weight + down_weight = sd["lora_down.weight"].to(torch.float).to(device) + up_weight = sd["lora_up.weight"].to(torch.float).to(device) + + # merge weight + if len(weight.size()) == 2: + # linear + weight = weight + self.multiplier * (up_weight @ down_weight) * self.scale + elif down_weight.size()[2:4] == (1, 1): + # conv2d 1x1 + weight = ( + weight + + self.multiplier + * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3) + * self.scale + ) + else: + # conv2d 3x3 + conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3) + # logger.info(conved.size(), weight.size(), module.stride, module.padding) + weight = weight + self.multiplier * conved * self.scale + + # set weight to org_module + org_sd["weight"] = weight.to(dtype) + self.org_module.load_state_dict(org_sd) + else: + # split_dims + total_dims = sum(self.split_dims) + for i in range(len(self.split_dims)): + # get up/down weight + down_weight = sd[f"lora_down.{i}.weight"].to(torch.float).to(device) # (rank, in_dim) + up_weight = sd[f"lora_up.{i}.weight"].to(torch.float).to(device) # (split dim, rank) + + # pad up_weight -> (total_dims, rank) + padded_up_weight = torch.zeros((total_dims, up_weight.size(0)), device=device, dtype=torch.float) + padded_up_weight[sum(self.split_dims[:i]) : sum(self.split_dims[: i + 1])] = up_weight + + # merge weight + weight = weight + self.multiplier * (up_weight @ down_weight) * self.scale + + # set weight to org_module + org_sd["weight"] = weight.to(dtype) + self.org_module.load_state_dict(org_sd) + + # 復元できるマージのため、このモジュールのweightを返す + def get_weight(self, multiplier=None): + if multiplier is None: + multiplier = self.multiplier + + # get up/down weight from module + up_weight = self.lora_up.weight.to(torch.float) + down_weight = self.lora_down.weight.to(torch.float) + + # pre-calculated weight + if len(down_weight.size()) == 2: + # linear + weight = self.multiplier * (up_weight @ down_weight) * self.scale + elif down_weight.size()[2:4] == (1, 1): + # conv2d 1x1 + weight = ( + self.multiplier + * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3) + * self.scale + ) + else: + # conv2d 3x3 + conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3) + weight = self.multiplier * conved * self.scale + + return weight + + def set_region(self, region): + self.region = region + self.region_mask = None + + def default_forward(self, x): + # logger.info(f"default_forward {self.lora_name} {x.size()}") + if self.split_dims is None: + lx = self.lora_down(x) + lx = self.lora_up(lx) + return self.org_forward(x) + lx * self.multiplier * self.scale + else: + lxs = [lora_down(x) for lora_down in self.lora_down] + lxs = [lora_up(lx) for lora_up, lx in zip(self.lora_up, lxs)] + return self.org_forward(x) + torch.cat(lxs, dim=-1) * self.multiplier * self.scale + + def forward(self, x): + if not self.enabled: + return self.org_forward(x) + return self.default_forward(x) + + +def create_network( + multiplier: float, + network_dim: Optional[int], + network_alpha: Optional[float], + ae: AutoencoderKL, + text_encoders: List[CLIPTextModel], + flux, + neuron_dropout: Optional[float] = None, + **kwargs, +): + if network_dim is None: + network_dim = 4 # default + if network_alpha is None: + network_alpha = 1.0 + + # extract dim/alpha for conv2d, and block dim + conv_dim = kwargs.get("conv_dim", None) + conv_alpha = kwargs.get("conv_alpha", None) + if conv_dim is not None: + conv_dim = int(conv_dim) + if conv_alpha is None: + conv_alpha = 1.0 + else: + conv_alpha = float(conv_alpha) + + # attn dim, mlp dim: only for DoubleStreamBlock. SingleStreamBlock is not supported because of combined qkv + img_attn_dim = kwargs.get("img_attn_dim", None) + txt_attn_dim = kwargs.get("txt_attn_dim", None) + img_mlp_dim = kwargs.get("img_mlp_dim", None) + txt_mlp_dim = kwargs.get("txt_mlp_dim", None) + img_mod_dim = kwargs.get("img_mod_dim", None) + txt_mod_dim = kwargs.get("txt_mod_dim", None) + single_dim = kwargs.get("single_dim", None) # SingleStreamBlock + single_mod_dim = kwargs.get("single_mod_dim", None) # SingleStreamBlock + if img_attn_dim is not None: + img_attn_dim = int(img_attn_dim) + if txt_attn_dim is not None: + txt_attn_dim = int(txt_attn_dim) + if img_mlp_dim is not None: + img_mlp_dim = int(img_mlp_dim) + if txt_mlp_dim is not None: + txt_mlp_dim = int(txt_mlp_dim) + if img_mod_dim is not None: + img_mod_dim = int(img_mod_dim) + if txt_mod_dim is not None: + txt_mod_dim = int(txt_mod_dim) + if single_dim is not None: + single_dim = int(single_dim) + if single_mod_dim is not None: + single_mod_dim = int(single_mod_dim) + type_dims = [img_attn_dim, txt_attn_dim, img_mlp_dim, txt_mlp_dim, img_mod_dim, txt_mod_dim, single_dim, single_mod_dim] + if all([d is None for d in type_dims]): + type_dims = None + + # in_dims [img, time, vector, guidance, txt] + in_dims = kwargs.get("in_dims", None) + if in_dims is not None: + in_dims = in_dims.strip() + if in_dims.startswith("[") and in_dims.endswith("]"): + in_dims = in_dims[1:-1] + in_dims = [int(d) for d in in_dims.split(",")] # is it better to use ast.literal_eval? + assert len(in_dims) == 5, f"invalid in_dims: {in_dims}, must be 5 dimensions (img, time, vector, guidance, txt)" + + # double/single train blocks + def parse_block_selection(selection: str, total_blocks: int) -> List[bool]: + """ + Parse a block selection string and return a list of booleans. + + Args: + selection (str): A string specifying which blocks to select. + total_blocks (int): The total number of blocks available. + + Returns: + List[bool]: A list of booleans indicating which blocks are selected. + """ + if selection == "all": + return [True] * total_blocks + if selection == "none" or selection == "": + return [False] * total_blocks + + selected = [False] * total_blocks + ranges = selection.split(",") + + for r in ranges: + if "-" in r: + start, end = map(str.strip, r.split("-")) + start = int(start) + end = int(end) + assert 0 <= start < total_blocks, f"invalid start index: {start}" + assert 0 <= end < total_blocks, f"invalid end index: {end}" + assert start <= end, f"invalid range: {start}-{end}" + for i in range(start, end + 1): + selected[i] = True + else: + index = int(r) + assert 0 <= index < total_blocks, f"invalid index: {index}" + selected[index] = True + + return selected + + train_double_block_indices = kwargs.get("train_double_block_indices", None) + train_single_block_indices = kwargs.get("train_single_block_indices", None) + if train_double_block_indices is not None: + train_double_block_indices = parse_block_selection(train_double_block_indices, NUM_DOUBLE_BLOCKS) + if train_single_block_indices is not None: + train_single_block_indices = parse_block_selection(train_single_block_indices, NUM_SINGLE_BLOCKS) + + # rank/module dropout + rank_dropout = kwargs.get("rank_dropout", None) + if rank_dropout is not None: + rank_dropout = float(rank_dropout) + module_dropout = kwargs.get("module_dropout", None) + if module_dropout is not None: + module_dropout = float(module_dropout) + + # single or double blocks + train_blocks = kwargs.get("train_blocks", None) # None (default), "all" (same as None), "single", "double" + if train_blocks is not None: + assert train_blocks in ["all", "single", "double"], f"invalid train_blocks: {train_blocks}" + + # split qkv + split_qkv = kwargs.get("split_qkv", False) + if split_qkv is not None: + split_qkv = True if split_qkv == "True" else False + + ggpo_beta = kwargs.get("ggpo_beta", None) + ggpo_sigma = kwargs.get("ggpo_sigma", None) + + if ggpo_beta is not None: + ggpo_beta = float(ggpo_beta) + + if ggpo_sigma is not None: + ggpo_sigma = float(ggpo_sigma) + + # train T5XXL + train_t5xxl = kwargs.get("train_t5xxl", False) + if train_t5xxl is not None: + train_t5xxl = True if train_t5xxl == "True" else False + + # verbose + verbose = kwargs.get("verbose", False) + if verbose is not None: + verbose = True if verbose == "True" else False + + # regex-specific learning rates + def parse_kv_pairs(kv_pair_str: str, is_int: bool) -> Dict[str, float]: + """ + Parse a string of key-value pairs separated by commas. + """ + pairs = {} + for pair in kv_pair_str.split(","): + pair = pair.strip() + if not pair: + continue + if "=" not in pair: + logger.warning(f"Invalid format: {pair}, expected 'key=value'") + continue + key, value = pair.split("=", 1) + key = key.strip() + value = value.strip() + try: + pairs[key] = int(value) if is_int else float(value) + except ValueError: + logger.warning(f"Invalid value for {key}: {value}") + return pairs + + # parse regular expression based learning rates + network_reg_lrs = kwargs.get("network_reg_lrs", None) + if network_reg_lrs is not None: + reg_lrs = parse_kv_pairs(network_reg_lrs, is_int=False) + else: + reg_lrs = None + + # regex-specific dimensions (ranks) + network_reg_dims = kwargs.get("network_reg_dims", None) + if network_reg_dims is not None: + reg_dims = parse_kv_pairs(network_reg_dims, is_int=True) + else: + reg_dims = None + + # すごく引数が多いな ( ^ω^)・・・ + network = LoRANetwork( + text_encoders, + flux, + multiplier=multiplier, + lora_dim=network_dim, + alpha=network_alpha, + dropout=neuron_dropout, + rank_dropout=rank_dropout, + module_dropout=module_dropout, + conv_lora_dim=conv_dim, + conv_alpha=conv_alpha, + train_blocks=train_blocks, + split_qkv=split_qkv, + train_t5xxl=train_t5xxl, + type_dims=type_dims, + in_dims=in_dims, + train_double_block_indices=train_double_block_indices, + train_single_block_indices=train_single_block_indices, + reg_dims=reg_dims, + ggpo_beta=ggpo_beta, + ggpo_sigma=ggpo_sigma, + reg_lrs=reg_lrs, + verbose=verbose, + ) + + loraplus_lr_ratio = kwargs.get("loraplus_lr_ratio", None) + loraplus_unet_lr_ratio = kwargs.get("loraplus_unet_lr_ratio", None) + loraplus_text_encoder_lr_ratio = kwargs.get("loraplus_text_encoder_lr_ratio", None) + loraplus_lr_ratio = float(loraplus_lr_ratio) if loraplus_lr_ratio is not None else None + loraplus_unet_lr_ratio = float(loraplus_unet_lr_ratio) if loraplus_unet_lr_ratio is not None else None + loraplus_text_encoder_lr_ratio = float(loraplus_text_encoder_lr_ratio) if loraplus_text_encoder_lr_ratio is not None else None + if loraplus_lr_ratio is not None or loraplus_unet_lr_ratio is not None or loraplus_text_encoder_lr_ratio is not None: + network.set_loraplus_lr_ratio(loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio) + + return network + + +# Create network from weights for inference, weights are not loaded here (because can be merged) +def create_network_from_weights(multiplier, file, ae, text_encoders, flux, weights_sd=None, for_inference=False, **kwargs): + if weights_sd is None: + if os.path.splitext(file)[1] == ".safetensors": + from safetensors.torch import load_file, safe_open + + weights_sd = load_file(file) + else: + weights_sd = torch.load(file, map_location="cpu") + + # get dim/alpha mapping, and train t5xxl + modules_dim = {} + modules_alpha = {} + train_t5xxl = None + for key, value in weights_sd.items(): + if "." not in key: + continue + + lora_name = key.split(".")[0] + if "alpha" in key: + modules_alpha[lora_name] = value + elif "lora_down" in key: + dim = value.size()[0] + modules_dim[lora_name] = dim + # logger.info(lora_name, value.size(), dim) + + if train_t5xxl is None or train_t5xxl is False: + train_t5xxl = "lora_te3" in lora_name + + if train_t5xxl is None: + train_t5xxl = False + + split_qkv = False # split_qkv is not needed to care, because state_dict is qkv combined + + module_class = LoRAInfModule if for_inference else LoRAModule + + network = LoRANetwork( + text_encoders, + flux, + multiplier=multiplier, + modules_dim=modules_dim, + modules_alpha=modules_alpha, + module_class=module_class, + split_qkv=split_qkv, + train_t5xxl=train_t5xxl, + ) + return network, weights_sd + + +class LoRANetwork(torch.nn.Module): + # FLUX_TARGET_REPLACE_MODULE = ["DoubleStreamBlock", "SingleStreamBlock"] + FLUX_TARGET_REPLACE_MODULE_DOUBLE = ["DoubleStreamBlock"] + FLUX_TARGET_REPLACE_MODULE_SINGLE = ["SingleStreamBlock"] + TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPSdpaAttention", "CLIPMLP", "T5Attention", "T5DenseGatedActDense"] + LORA_PREFIX_FLUX = "lora_unet" # make ComfyUI compatible + LORA_PREFIX_TEXT_ENCODER_CLIP = "lora_te1" + LORA_PREFIX_TEXT_ENCODER_T5 = "lora_te3" # make ComfyUI compatible + + @classmethod + def get_qkv_mlp_split_dims(cls) -> List[int]: + return [3072] * 3 + [12288] + + def __init__( + self, + text_encoders: Union[List[CLIPTextModel], CLIPTextModel], + unet, + multiplier: float = 1.0, + lora_dim: int = 4, + alpha: float = 1, + dropout: Optional[float] = None, + rank_dropout: Optional[float] = None, + module_dropout: Optional[float] = None, + conv_lora_dim: Optional[int] = None, + conv_alpha: Optional[float] = None, + module_class: Type[object] = LoRAModule, + modules_dim: Optional[Dict[str, int]] = None, + modules_alpha: Optional[Dict[str, int]] = None, + train_blocks: Optional[str] = None, + split_qkv: bool = False, + train_t5xxl: bool = False, + type_dims: Optional[List[int]] = None, + in_dims: Optional[List[int]] = None, + train_double_block_indices: Optional[List[bool]] = None, + train_single_block_indices: Optional[List[bool]] = None, + reg_dims: Optional[Dict[str, int]] = None, + ggpo_beta: Optional[float] = None, + ggpo_sigma: Optional[float] = None, + reg_lrs: Optional[Dict[str, float]] = None, + verbose: Optional[bool] = False, + ) -> None: + super().__init__() + self.multiplier = multiplier + + self.lora_dim = lora_dim + self.alpha = alpha + self.conv_lora_dim = conv_lora_dim + self.conv_alpha = conv_alpha + self.dropout = dropout + self.rank_dropout = rank_dropout + self.module_dropout = module_dropout + self.train_blocks = train_blocks if train_blocks is not None else "all" + self.split_qkv = split_qkv + self.train_t5xxl = train_t5xxl + + self.type_dims = type_dims + self.in_dims = in_dims + self.train_double_block_indices = train_double_block_indices + self.train_single_block_indices = train_single_block_indices + self.reg_dims = reg_dims + self.reg_lrs = reg_lrs + + self.loraplus_lr_ratio = None + self.loraplus_unet_lr_ratio = None + self.loraplus_text_encoder_lr_ratio = None + + if modules_dim is not None: + logger.info(f"create LoRA network from weights") + self.in_dims = [0] * 5 # create in_dims + # verbose = True + else: + logger.info(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}") + logger.info( + f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}" + ) + # if self.conv_lora_dim is not None: + # logger.info( + # f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}" + # ) + + if ggpo_beta is not None and ggpo_sigma is not None: + logger.info(f"LoRA-GGPO training sigma: {ggpo_sigma} beta: {ggpo_beta}") + + if self.split_qkv: + logger.info(f"split qkv for LoRA") + if self.train_blocks is not None: + logger.info(f"train {self.train_blocks} blocks only") + + if train_t5xxl: + logger.info(f"train T5XXL as well") + + # create module instances + def create_modules( + is_flux: bool, + text_encoder_idx: Optional[int], + root_module: torch.nn.Module, + target_replace_modules: List[str], + filter: Optional[str] = None, + default_dim: Optional[int] = None, + ) -> List[LoRAModule]: + prefix = ( + self.LORA_PREFIX_FLUX + if is_flux + else (self.LORA_PREFIX_TEXT_ENCODER_CLIP if text_encoder_idx == 0 else self.LORA_PREFIX_TEXT_ENCODER_T5) + ) + + loras = [] + skipped = [] + for name, module in root_module.named_modules(): + if target_replace_modules is None or module.__class__.__name__ in target_replace_modules: + if target_replace_modules is None: # dirty hack for all modules + module = root_module # search all modules + + for child_name, child_module in module.named_modules(): + is_linear = child_module.__class__.__name__ == "Linear" + is_conv2d = child_module.__class__.__name__ == "Conv2d" + is_conv2d_1x1 = is_conv2d and child_module.kernel_size == (1, 1) + + if is_linear or is_conv2d: + lora_name = prefix + "." + (name + "." if name else "") + child_name + lora_name = lora_name.replace(".", "_") + + if filter is not None and not filter in lora_name: + continue + + dim = None + alpha = None + + if modules_dim is not None: + # モジュール指定あり + if lora_name in modules_dim: + dim = modules_dim[lora_name] + alpha = modules_alpha[lora_name] + elif self.reg_dims is not None: + for reg, d in self.reg_dims.items(): + if re.search(reg, lora_name): + dim = d + alpha = self.alpha + logger.info(f"LoRA {lora_name} matched with regex {reg}, using dim: {dim}") + break + + # if modules_dim is None, we use default lora_dim. if modules_dim is not None, we use the specified dim (no default) + if dim is None and modules_dim is None: + if is_linear or is_conv2d_1x1: + dim = default_dim if default_dim is not None else self.lora_dim + alpha = self.alpha + + if is_flux and type_dims is not None: + identifier = [ + ("img_attn",), + ("txt_attn",), + ("img_mlp",), + ("txt_mlp",), + ("img_mod",), + ("txt_mod",), + ("single_blocks", "linear"), + ("modulation",), + ] + for i, d in enumerate(type_dims): + if d is not None and all([id in lora_name for id in identifier[i]]): + dim = d # may be 0 for skip + break + + if ( + is_flux + and dim + and ( + self.train_double_block_indices is not None + or self.train_single_block_indices is not None + ) + and ("double" in lora_name or "single" in lora_name) + ): + # "lora_unet_double_blocks_0_..." or "lora_unet_single_blocks_0_..." + block_index = int(lora_name.split("_")[4]) # bit dirty + if ( + "double" in lora_name + and self.train_double_block_indices is not None + and not self.train_double_block_indices[block_index] + ): + dim = 0 + elif ( + "single" in lora_name + and self.train_single_block_indices is not None + and not self.train_single_block_indices[block_index] + ): + dim = 0 + + elif self.conv_lora_dim is not None: + dim = self.conv_lora_dim + alpha = self.conv_alpha + + if dim is None or dim == 0: + # skipした情報を出力 + if is_linear or is_conv2d_1x1 or (self.conv_lora_dim is not None): + skipped.append(lora_name) + continue + + # qkv split + split_dims = None + if is_flux and split_qkv: + if "double" in lora_name and "qkv" in lora_name: + (split_dims,) = self.get_qkv_mlp_split_dims()[:3] # qkv only + elif "single" in lora_name and "linear1" in lora_name: + split_dims = self.get_qkv_mlp_split_dims() # qkv + mlp + + lora = module_class( + lora_name, + child_module, + self.multiplier, + dim, + alpha, + dropout=dropout, + rank_dropout=rank_dropout, + module_dropout=module_dropout, + split_dims=split_dims, + ggpo_beta=ggpo_beta, + ggpo_sigma=ggpo_sigma, + ) + loras.append(lora) + + if target_replace_modules is None: + break # all modules are searched + return loras, skipped + + # create LoRA for text encoder + # 毎回すべてのモジュールを作るのは無駄なので要検討 + self.text_encoder_loras: List[Union[LoRAModule, LoRAInfModule]] = [] + skipped_te = [] + for i, text_encoder in enumerate(text_encoders): + index = i + if text_encoder is None: + logger.info(f"Text Encoder {index+1} is None, skipping LoRA creation for this encoder.") + continue + if not train_t5xxl and index > 0: # 0: CLIP, 1: T5XXL, so we skip T5XXL if train_t5xxl is False + break + + logger.info(f"create LoRA for Text Encoder {index+1}:") + + text_encoder_loras, skipped = create_modules(False, index, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE) + logger.info(f"create LoRA for Text Encoder {index+1}: {len(text_encoder_loras)} modules.") + self.text_encoder_loras.extend(text_encoder_loras) + skipped_te += skipped + + # create LoRA for U-Net + if self.train_blocks == "all": + target_replace_modules = LoRANetwork.FLUX_TARGET_REPLACE_MODULE_DOUBLE + LoRANetwork.FLUX_TARGET_REPLACE_MODULE_SINGLE + elif self.train_blocks == "single": + target_replace_modules = LoRANetwork.FLUX_TARGET_REPLACE_MODULE_SINGLE + elif self.train_blocks == "double": + target_replace_modules = LoRANetwork.FLUX_TARGET_REPLACE_MODULE_DOUBLE + + self.unet_loras: List[Union[LoRAModule, LoRAInfModule]] + self.unet_loras, skipped_un = create_modules(True, None, unet, target_replace_modules) + + # img, time, vector, guidance, txt + if self.in_dims: + for filter, in_dim in zip(["_img_in", "_time_in", "_vector_in", "_guidance_in", "_txt_in"], self.in_dims): + loras, _ = create_modules(True, None, unet, None, filter=filter, default_dim=in_dim) + self.unet_loras.extend(loras) + + logger.info(f"create LoRA for FLUX {self.train_blocks} blocks: {len(self.unet_loras)} modules.") + if verbose: + for lora in self.unet_loras: + logger.info(f"\t{lora.lora_name:50} {lora.lora_dim}, {lora.alpha}") + + skipped = skipped_te + skipped_un + if verbose and len(skipped) > 0: + logger.warning( + f"because dim (rank) is 0, {len(skipped)} LoRA modules are skipped / dim (rank)が0の為、次の{len(skipped)}個のLoRAモジュールはスキップされます:" + ) + for name in skipped: + logger.info(f"\t{name}") + + # assertion + names = set() + for lora in self.text_encoder_loras + self.unet_loras: + assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}" + names.add(lora.lora_name) + + def set_multiplier(self, multiplier): + self.multiplier = multiplier + for lora in self.text_encoder_loras + self.unet_loras: + lora.multiplier = self.multiplier + + def set_enabled(self, is_enabled): + for lora in self.text_encoder_loras + self.unet_loras: + lora.enabled = is_enabled + + def update_norms(self): + for lora in self.text_encoder_loras + self.unet_loras: + lora.update_norms() + + def update_grad_norms(self): + for lora in self.text_encoder_loras + self.unet_loras: + lora.update_grad_norms() + + def grad_norms(self) -> Tensor | None: + grad_norms = [] + for lora in self.text_encoder_loras + self.unet_loras: + if hasattr(lora, "grad_norms") and lora.grad_norms is not None: + grad_norms.append(lora.grad_norms.mean(dim=0)) + return torch.stack(grad_norms) if len(grad_norms) > 0 else None + + def weight_norms(self) -> Tensor | None: + weight_norms = [] + for lora in self.text_encoder_loras + self.unet_loras: + if hasattr(lora, "weight_norms") and lora.weight_norms is not None: + weight_norms.append(lora.weight_norms.mean(dim=0)) + return torch.stack(weight_norms) if len(weight_norms) > 0 else None + + def combined_weight_norms(self) -> Tensor | None: + combined_weight_norms = [] + for lora in self.text_encoder_loras + self.unet_loras: + if hasattr(lora, "combined_weight_norms") and lora.combined_weight_norms is not None: + combined_weight_norms.append(lora.combined_weight_norms.mean(dim=0)) + return torch.stack(combined_weight_norms) if len(combined_weight_norms) > 0 else None + + def load_weights(self, file): + if os.path.splitext(file)[1] == ".safetensors": + from safetensors.torch import load_file + + weights_sd = load_file(file) + else: + weights_sd = torch.load(file, map_location="cpu") + + info = self.load_state_dict(weights_sd, False) + return info + + def load_state_dict(self, state_dict, strict=True): + # override to convert original weight to split qkv + if not self.split_qkv: + return super().load_state_dict(state_dict, strict) + + # split qkv + for key in list(state_dict.keys()): + if "double" in key and "qkv" in key: + split_dims = self.get_qkv_mlp_split_dims()[:3] # qkv only + elif "single" in key and "linear1" in key: + split_dims = self.get_qkv_mlp_split_dims() # qkv + mlp + else: + continue + + weight = state_dict[key] + lora_name = key.split(".")[0] + if "lora_down" in key and "weight" in key: + # dense weight (rank*3, in_dim) + split_weight = torch.chunk(weight, len(split_dims), dim=0) + for i, split_w in enumerate(split_weight): + state_dict[f"{lora_name}.lora_down.{i}.weight"] = split_w + + del state_dict[key] + # print(f"split {key}: {weight.shape} to {[w.shape for w in split_weight]}") + elif "lora_up" in key and "weight" in key: + # sparse weight (out_dim=sum(split_dims), rank*3) + rank = weight.size(1) // len(split_dims) + i = 0 + for j in range(len(split_dims)): + state_dict[f"{lora_name}.lora_up.{j}.weight"] = weight[i : i + split_dims[j], j * rank : (j + 1) * rank] + i += split_dims[j] + del state_dict[key] + + # # check is sparse + # i = 0 + # is_zero = True + # for j in range(len(split_dims)): + # for k in range(len(split_dims)): + # if j == k: + # continue + # is_zero = is_zero and torch.all(weight[i : i + split_dims[j], k * rank : (k + 1) * rank] == 0) + # i += split_dims[j] + # if not is_zero: + # logger.warning(f"weight is not sparse: {key}") + # else: + # logger.info(f"weight is sparse: {key}") + + # print( + # f"split {key}: {weight.shape} to {[state_dict[k].shape for k in [f'{lora_name}.lora_up.{j}.weight' for j in range(len(split_dims))]]}" + # ) + + # alpha is unchanged + + return super().load_state_dict(state_dict, strict) + + def state_dict(self, destination=None, prefix="", keep_vars=False): + if not self.split_qkv: + return super().state_dict(destination, prefix, keep_vars) + + # merge qkv + state_dict = super().state_dict(destination, prefix, keep_vars) + new_state_dict = {} + for key in list(state_dict.keys()): + if "double" in key and "qkv" in key: + split_dims = self.get_qkv_mlp_split_dims()[:3] # qkv only + elif "single" in key and "linear1" in key: + split_dims = self.get_qkv_mlp_split_dims() # qkv + mlp + else: + new_state_dict[key] = state_dict[key] + continue + + if key not in state_dict: + continue # already merged + + lora_name = key.split(".")[0] + + # (rank, in_dim) * 3 + down_weights = [state_dict.pop(f"{lora_name}.lora_down.{i}.weight") for i in range(len(split_dims))] + # (split dim, rank) * 3 + up_weights = [state_dict.pop(f"{lora_name}.lora_up.{i}.weight") for i in range(len(split_dims))] + + alpha = state_dict.pop(f"{lora_name}.alpha") + + # merge down weight + down_weight = torch.cat(down_weights, dim=0) # (rank, split_dim) * 3 -> (rank*3, sum of split_dim) + + # merge up weight (sum of split_dim, rank*3) + rank = up_weights[0].size(1) + up_weight = torch.zeros((sum(split_dims), down_weight.size(0)), device=down_weight.device, dtype=down_weight.dtype) + i = 0 + for j in range(len(split_dims)): + up_weight[i : i + split_dims[j], j * rank : (j + 1) * rank] = up_weights[j] + i += split_dims[j] + + new_state_dict[f"{lora_name}.lora_down.weight"] = down_weight + new_state_dict[f"{lora_name}.lora_up.weight"] = up_weight + new_state_dict[f"{lora_name}.alpha"] = alpha + + # print( + # f"merged {lora_name}: {lora_name}, {[w.shape for w in down_weights]}, {[w.shape for w in up_weights]} to {down_weight.shape}, {up_weight.shape}" + # ) + print(f"new key: {lora_name}.lora_down.weight, {lora_name}.lora_up.weight, {lora_name}.alpha") + + return new_state_dict + + def apply_to(self, text_encoders, flux, apply_text_encoder=True, apply_unet=True): + if apply_text_encoder: + logger.info(f"enable LoRA for text encoder: {len(self.text_encoder_loras)} modules") + else: + self.text_encoder_loras = [] + + if apply_unet: + logger.info(f"enable LoRA for U-Net: {len(self.unet_loras)} modules") + else: + self.unet_loras = [] + + for lora in self.text_encoder_loras + self.unet_loras: + lora.apply_to() + self.add_module(lora.lora_name, lora) + + # マージできるかどうかを返す + def is_mergeable(self): + return True + + # TODO refactor to common function with apply_to + def merge_to(self, text_encoders, flux, weights_sd, dtype=None, device=None): + apply_text_encoder = apply_unet = False + for key in weights_sd.keys(): + if key.startswith(LoRANetwork.LORA_PREFIX_TEXT_ENCODER_CLIP) or key.startswith(LoRANetwork.LORA_PREFIX_TEXT_ENCODER_T5): + apply_text_encoder = True + elif key.startswith(LoRANetwork.LORA_PREFIX_FLUX): + apply_unet = True + + if apply_text_encoder: + logger.info("enable LoRA for text encoder") + else: + self.text_encoder_loras = [] + + if apply_unet: + logger.info("enable LoRA for U-Net") + else: + self.unet_loras = [] + + for lora in self.text_encoder_loras + self.unet_loras: + sd_for_lora = {} + for key in weights_sd.keys(): + if key.startswith(lora.lora_name): + sd_for_lora[key[len(lora.lora_name) + 1 :]] = weights_sd[key] + lora.merge_to(sd_for_lora, dtype, device) + + logger.info(f"weights are merged") + + def set_loraplus_lr_ratio(self, loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio): + self.loraplus_lr_ratio = loraplus_lr_ratio + self.loraplus_unet_lr_ratio = loraplus_unet_lr_ratio + self.loraplus_text_encoder_lr_ratio = loraplus_text_encoder_lr_ratio + + logger.info(f"LoRA+ UNet LR Ratio: {self.loraplus_unet_lr_ratio or self.loraplus_lr_ratio}") + logger.info(f"LoRA+ Text Encoder LR Ratio: {self.loraplus_text_encoder_lr_ratio or self.loraplus_lr_ratio}") + + def prepare_optimizer_params_with_multiple_te_lrs(self, text_encoder_lr, unet_lr, default_lr): + # make sure text_encoder_lr as list of two elements + # if float, use the same value for both text encoders + if text_encoder_lr is None or (isinstance(text_encoder_lr, list) and len(text_encoder_lr) == 0): + text_encoder_lr = [default_lr, default_lr] + elif isinstance(text_encoder_lr, float) or isinstance(text_encoder_lr, int): + text_encoder_lr = [float(text_encoder_lr), float(text_encoder_lr)] + elif len(text_encoder_lr) == 1: + text_encoder_lr = [text_encoder_lr[0], text_encoder_lr[0]] + + self.requires_grad_(True) + + all_params = [] + lr_descriptions = [] + + reg_lrs_list = list(self.reg_lrs.items()) if self.reg_lrs is not None else [] + + def assemble_params(loras, lr, loraplus_ratio): + param_groups = {"lora": {}, "plus": {}} + # regular expression param groups: {"reg_lr_0": {"lora": {}, "plus": {}}, ...} + reg_groups = {} + + for lora in loras: + # check if this lora matches any regex learning rate + matched_reg_lr = None + for i, (regex_str, reg_lr) in enumerate(reg_lrs_list): + try: + if re.search(regex_str, lora.lora_name): + matched_reg_lr = (i, reg_lr) + logger.info(f"Module {lora.lora_name} matched regex '{regex_str}' -> LR {reg_lr}") + break + except re.error: + # regex error should have been caught during parsing, but just in case + continue + + for name, param in lora.named_parameters(): + param_key = f"{lora.lora_name}.{name}" + is_plus = loraplus_ratio is not None and "lora_up" in name + + if matched_reg_lr is not None: + # use regex-specific learning rate + reg_idx, reg_lr = matched_reg_lr + group_key = f"reg_lr_{reg_idx}" + if group_key not in reg_groups: + reg_groups[group_key] = {"lora": {}, "plus": {}, "lr": reg_lr} + + if is_plus: + reg_groups[group_key]["plus"][param_key] = param + else: + reg_groups[group_key]["lora"][param_key] = param + else: + # use default learning rate + if is_plus: + param_groups["plus"][param_key] = param + else: + param_groups["lora"][param_key] = param + + params = [] + descriptions = [] + + # process regex-specific groups first (higher priority) + for group_key in sorted(reg_groups.keys()): + group = reg_groups[group_key] + reg_lr = group["lr"] + + for param_type in ["lora", "plus"]: + if len(group[param_type]) == 0: + continue + + param_data = {"params": group[param_type].values()} + + if param_type == "plus" and loraplus_ratio is not None: + param_data["lr"] = reg_lr * loraplus_ratio + else: + param_data["lr"] = reg_lr + + if param_data.get("lr", None) == 0 or param_data.get("lr", None) is None: + continue + + params.append(param_data) + desc = f"reg_lr_{group_key.split('_')[-1]}" + if param_type == "plus": + desc += " plus" + descriptions.append(desc) + + # process default groups + for key in param_groups.keys(): + param_data = {"params": param_groups[key].values()} + + if len(param_data["params"]) == 0: + continue + + if lr is not None: + if key == "plus": + param_data["lr"] = lr * loraplus_ratio + else: + param_data["lr"] = lr + + if param_data.get("lr", None) == 0 or param_data.get("lr", None) is None: + logger.info("NO LR skipping!") + continue + + params.append(param_data) + descriptions.append("plus" if key == "plus" else "") + + return params, descriptions + + if self.text_encoder_loras: + loraplus_lr_ratio = self.loraplus_text_encoder_lr_ratio or self.loraplus_lr_ratio + + # split text encoder loras for te1 and te3 + te1_loras = [lora for lora in self.text_encoder_loras if lora.lora_name.startswith(self.LORA_PREFIX_TEXT_ENCODER_CLIP)] + te3_loras = [lora for lora in self.text_encoder_loras if lora.lora_name.startswith(self.LORA_PREFIX_TEXT_ENCODER_T5)] + if len(te1_loras) > 0: + logger.info(f"Text Encoder 1 (CLIP-L): {len(te1_loras)} modules, LR {text_encoder_lr[0]}") + params, descriptions = assemble_params(te1_loras, text_encoder_lr[0], loraplus_lr_ratio) + all_params.extend(params) + lr_descriptions.extend(["textencoder 1 " + (" " + d if d else "") for d in descriptions]) + if len(te3_loras) > 0: + logger.info(f"Text Encoder 2 (T5XXL): {len(te3_loras)} modules, LR {text_encoder_lr[1]}") + params, descriptions = assemble_params(te3_loras, text_encoder_lr[1], loraplus_lr_ratio) + all_params.extend(params) + lr_descriptions.extend(["textencoder 2 " + (" " + d if d else "") for d in descriptions]) + + if self.unet_loras: + params, descriptions = assemble_params( + self.unet_loras, + unet_lr if unet_lr is not None else default_lr, + self.loraplus_unet_lr_ratio or self.loraplus_lr_ratio, + ) + all_params.extend(params) + lr_descriptions.extend(["unet" + (" " + d if d else "") for d in descriptions]) + + return all_params, lr_descriptions + + def enable_gradient_checkpointing(self): + # not supported + pass + + def prepare_grad_etc(self, text_encoder, unet): + self.requires_grad_(True) + + def on_epoch_start(self, text_encoder, unet): + self.train() + + def get_trainable_params(self): + return self.parameters() + + def save_weights(self, file, dtype, metadata): + if metadata is not None and len(metadata) == 0: + metadata = None + + state_dict = self.state_dict() + + if dtype is not None: + for key in list(state_dict.keys()): + v = state_dict[key] + v = v.detach().clone().to("cpu").to(dtype) + state_dict[key] = v + + if os.path.splitext(file)[1] == ".safetensors": + from safetensors.torch import save_file + from library import train_util + + # Precalculate model hashes to save time on indexing + if metadata is None: + metadata = {} + model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata) + metadata["sshs_model_hash"] = model_hash + metadata["sshs_legacy_hash"] = legacy_hash + + save_file(state_dict, file, metadata) + else: + torch.save(state_dict, file) + + def backup_weights(self): + # 重みのバックアップを行う + loras: List[LoRAInfModule] = self.text_encoder_loras + self.unet_loras + for lora in loras: + org_module = lora.org_module_ref[0] + if not hasattr(org_module, "_lora_org_weight"): + sd = org_module.state_dict() + org_module._lora_org_weight = sd["weight"].detach().clone() + org_module._lora_restored = True + + def restore_weights(self): + # 重みのリストアを行う + loras: List[LoRAInfModule] = self.text_encoder_loras + self.unet_loras + for lora in loras: + org_module = lora.org_module_ref[0] + if not org_module._lora_restored: + sd = org_module.state_dict() + sd["weight"] = org_module._lora_org_weight + org_module.load_state_dict(sd) + org_module._lora_restored = True + + def pre_calculation(self): + # 事前計算を行う + loras: List[LoRAInfModule] = self.text_encoder_loras + self.unet_loras + for lora in loras: + org_module = lora.org_module_ref[0] + sd = org_module.state_dict() + + org_weight = sd["weight"] + lora_weight = lora.get_weight().to(org_weight.device, dtype=org_weight.dtype) + sd["weight"] = org_weight + lora_weight + assert sd["weight"].shape == org_weight.shape + org_module.load_state_dict(sd) + + org_module._lora_restored = False + lora.enabled = False + + def apply_max_norm_regularization(self, max_norm_value, device): + downkeys = [] + upkeys = [] + alphakeys = [] + norms = [] + keys_scaled = 0 + + state_dict = self.state_dict() + for key in state_dict.keys(): + if "lora_down" in key and "weight" in key: + downkeys.append(key) + upkeys.append(key.replace("lora_down", "lora_up")) + alphakeys.append(key.replace("lora_down.weight", "alpha")) + + for i in range(len(downkeys)): + down = state_dict[downkeys[i]].to(device) + up = state_dict[upkeys[i]].to(device) + alpha = state_dict[alphakeys[i]].to(device) + dim = down.shape[0] + scale = alpha / dim + + if up.shape[2:] == (1, 1) and down.shape[2:] == (1, 1): + updown = (up.squeeze(2).squeeze(2) @ down.squeeze(2).squeeze(2)).unsqueeze(2).unsqueeze(3) + elif up.shape[2:] == (3, 3) or down.shape[2:] == (3, 3): + updown = torch.nn.functional.conv2d(down.permute(1, 0, 2, 3), up).permute(1, 0, 2, 3) + else: + updown = up @ down + + updown *= scale + + norm = updown.norm().clamp(min=max_norm_value / 2) + desired = torch.clamp(norm, max=max_norm_value) + ratio = desired.cpu() / norm.cpu() + sqrt_ratio = ratio**0.5 + if ratio != 1: + keys_scaled += 1 + state_dict[upkeys[i]] *= sqrt_ratio + state_dict[downkeys[i]] *= sqrt_ratio + scalednorm = updown.norm() * ratio + norms.append(scalednorm.item()) + + return keys_scaled, sum(norms) / len(norms), max(norms) diff --git a/networks/lora_hunyuan_image.py b/networks/lora_hunyuan_image.py new file mode 100644 index 000000000..3e801f950 --- /dev/null +++ b/networks/lora_hunyuan_image.py @@ -0,0 +1,378 @@ +# temporary minimum implementation of LoRA +# FLUX doesn't have Conv2d, so we ignore it +# TODO commonize with the original implementation + +# LoRA network module +# reference: +# https://github.com/microsoft/LoRA/blob/main/loralib/layers.py +# https://github.com/cloneofsimo/lora/blob/master/lora_diffusion/lora.py + +import os +from typing import Dict, List, Optional, Type, Union +import torch +import torch.nn as nn +from torch import Tensor +import re + +from networks import lora_flux +from library.hunyuan_image_vae import HunyuanVAE2D + +from library.utils import setup_logging + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + + +NUM_DOUBLE_BLOCKS = 20 +NUM_SINGLE_BLOCKS = 40 + + +def create_network( + multiplier: float, + network_dim: Optional[int], + network_alpha: Optional[float], + vae: HunyuanVAE2D, + text_encoders: List[nn.Module], + flux, + neuron_dropout: Optional[float] = None, + **kwargs, +): + if network_dim is None: + network_dim = 4 # default + if network_alpha is None: + network_alpha = 1.0 + + # extract dim/alpha for conv2d, and block dim + conv_dim = kwargs.get("conv_dim", None) + conv_alpha = kwargs.get("conv_alpha", None) + if conv_dim is not None: + conv_dim = int(conv_dim) + if conv_alpha is None: + conv_alpha = 1.0 + else: + conv_alpha = float(conv_alpha) + + # rank/module dropout + rank_dropout = kwargs.get("rank_dropout", None) + if rank_dropout is not None: + rank_dropout = float(rank_dropout) + module_dropout = kwargs.get("module_dropout", None) + if module_dropout is not None: + module_dropout = float(module_dropout) + + # split qkv + split_qkv = kwargs.get("split_qkv", False) + if split_qkv is not None: + split_qkv = True if split_qkv == "True" else False + + ggpo_beta = kwargs.get("ggpo_beta", None) + ggpo_sigma = kwargs.get("ggpo_sigma", None) + + if ggpo_beta is not None: + ggpo_beta = float(ggpo_beta) + + if ggpo_sigma is not None: + ggpo_sigma = float(ggpo_sigma) + + # verbose + verbose = kwargs.get("verbose", False) + if verbose is not None: + verbose = True if verbose == "True" else False + + # regex-specific learning rates + def parse_kv_pairs(kv_pair_str: str, is_int: bool) -> Dict[str, float]: + """ + Parse a string of key-value pairs separated by commas. + """ + pairs = {} + for pair in kv_pair_str.split(","): + pair = pair.strip() + if not pair: + continue + if "=" not in pair: + logger.warning(f"Invalid format: {pair}, expected 'key=value'") + continue + key, value = pair.split("=", 1) + key = key.strip() + value = value.strip() + try: + pairs[key] = int(value) if is_int else float(value) + except ValueError: + logger.warning(f"Invalid value for {key}: {value}") + return pairs + + # parse regular expression based learning rates + network_reg_lrs = kwargs.get("network_reg_lrs", None) + if network_reg_lrs is not None: + reg_lrs = parse_kv_pairs(network_reg_lrs, is_int=False) + else: + reg_lrs = None + + # regex-specific dimensions (ranks) + network_reg_dims = kwargs.get("network_reg_dims", None) + if network_reg_dims is not None: + reg_dims = parse_kv_pairs(network_reg_dims, is_int=True) + else: + reg_dims = None + + # Too many arguments ( ^ω^)・・・ + network = HunyuanImageLoRANetwork( + text_encoders, + flux, + multiplier=multiplier, + lora_dim=network_dim, + alpha=network_alpha, + dropout=neuron_dropout, + rank_dropout=rank_dropout, + module_dropout=module_dropout, + conv_lora_dim=conv_dim, + conv_alpha=conv_alpha, + split_qkv=split_qkv, + reg_dims=reg_dims, + ggpo_beta=ggpo_beta, + ggpo_sigma=ggpo_sigma, + reg_lrs=reg_lrs, + verbose=verbose, + ) + + loraplus_lr_ratio = kwargs.get("loraplus_lr_ratio", None) + loraplus_unet_lr_ratio = kwargs.get("loraplus_unet_lr_ratio", None) + loraplus_text_encoder_lr_ratio = kwargs.get("loraplus_text_encoder_lr_ratio", None) + loraplus_lr_ratio = float(loraplus_lr_ratio) if loraplus_lr_ratio is not None else None + loraplus_unet_lr_ratio = float(loraplus_unet_lr_ratio) if loraplus_unet_lr_ratio is not None else None + loraplus_text_encoder_lr_ratio = float(loraplus_text_encoder_lr_ratio) if loraplus_text_encoder_lr_ratio is not None else None + if loraplus_lr_ratio is not None or loraplus_unet_lr_ratio is not None or loraplus_text_encoder_lr_ratio is not None: + network.set_loraplus_lr_ratio(loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio) + + return network + + +# Create network from weights for inference, weights are not loaded here (because can be merged) +def create_network_from_weights(multiplier, file, ae, text_encoders, flux, weights_sd=None, for_inference=False, **kwargs): + if weights_sd is None: + if os.path.splitext(file)[1] == ".safetensors": + from safetensors.torch import load_file, safe_open + + weights_sd = load_file(file) + else: + weights_sd = torch.load(file, map_location="cpu") + + # get dim/alpha mapping, and train t5xxl + modules_dim = {} + modules_alpha = {} + for key, value in weights_sd.items(): + if "." not in key: + continue + + lora_name = key.split(".")[0] + if "alpha" in key: + modules_alpha[lora_name] = value + elif "lora_down" in key: + dim = value.size()[0] + modules_dim[lora_name] = dim + # logger.info(lora_name, value.size(), dim) + + split_qkv = False # split_qkv is not needed to care, because state_dict is qkv combined + + module_class = lora_flux.LoRAInfModule if for_inference else lora_flux.LoRAModule + + network = HunyuanImageLoRANetwork( + text_encoders, + flux, + multiplier=multiplier, + modules_dim=modules_dim, + modules_alpha=modules_alpha, + module_class=module_class, + split_qkv=split_qkv, + ) + return network, weights_sd + + +class HunyuanImageLoRANetwork(lora_flux.LoRANetwork): + TARGET_REPLACE_MODULE_DOUBLE = ["MMDoubleStreamBlock"] + TARGET_REPLACE_MODULE_SINGLE = ["MMSingleStreamBlock"] + LORA_PREFIX_HUNYUAN_IMAGE_DIT = "lora_unet" # make ComfyUI compatible + + @classmethod + def get_qkv_mlp_split_dims(cls) -> List[int]: + return [3584] * 3 + [14336] + + def __init__( + self, + text_encoders: list[nn.Module], + unet, + multiplier: float = 1.0, + lora_dim: int = 4, + alpha: float = 1, + dropout: Optional[float] = None, + rank_dropout: Optional[float] = None, + module_dropout: Optional[float] = None, + conv_lora_dim: Optional[int] = None, + conv_alpha: Optional[float] = None, + module_class: Type[object] = lora_flux.LoRAModule, + modules_dim: Optional[Dict[str, int]] = None, + modules_alpha: Optional[Dict[str, int]] = None, + split_qkv: bool = False, + reg_dims: Optional[Dict[str, int]] = None, + ggpo_beta: Optional[float] = None, + ggpo_sigma: Optional[float] = None, + reg_lrs: Optional[Dict[str, float]] = None, + verbose: Optional[bool] = False, + ) -> None: + nn.Module.__init__(self) + self.multiplier = multiplier + + self.lora_dim = lora_dim + self.alpha = alpha + self.conv_lora_dim = conv_lora_dim + self.conv_alpha = conv_alpha + self.dropout = dropout + self.rank_dropout = rank_dropout + self.module_dropout = module_dropout + self.split_qkv = split_qkv + self.reg_dims = reg_dims + self.reg_lrs = reg_lrs + + self.loraplus_lr_ratio = None + self.loraplus_unet_lr_ratio = None + self.loraplus_text_encoder_lr_ratio = None + + if modules_dim is not None: + logger.info(f"create LoRA network from weights") + self.in_dims = [0] * 5 # create in_dims + # verbose = True + else: + logger.info(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}") + logger.info( + f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}" + ) + # if self.conv_lora_dim is not None: + # logger.info( + # f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}" + # ) + + if ggpo_beta is not None and ggpo_sigma is not None: + logger.info(f"LoRA-GGPO training sigma: {ggpo_sigma} beta: {ggpo_beta}") + + if self.split_qkv: + logger.info(f"split qkv for LoRA") + + # create module instances + def create_modules( + is_dit: bool, + text_encoder_idx: Optional[int], + root_module: torch.nn.Module, + target_replace_modules: List[str], + filter: Optional[str] = None, + default_dim: Optional[int] = None, + ) -> List[lora_flux.LoRAModule]: + assert is_dit, "only DIT is supported now" + + prefix = self.LORA_PREFIX_HUNYUAN_IMAGE_DIT + + loras = [] + skipped = [] + for name, module in root_module.named_modules(): + if target_replace_modules is None or module.__class__.__name__ in target_replace_modules: + if target_replace_modules is None: # dirty hack for all modules + module = root_module # search all modules + + for child_name, child_module in module.named_modules(): + is_linear = child_module.__class__.__name__ == "Linear" + is_conv2d = child_module.__class__.__name__ == "Conv2d" + is_conv2d_1x1 = is_conv2d and child_module.kernel_size == (1, 1) + + if is_linear or is_conv2d: + lora_name = prefix + "." + (name + "." if name else "") + child_name + lora_name = lora_name.replace(".", "_") + + if filter is not None and not filter in lora_name: + continue + + dim = None + alpha = None + + if modules_dim is not None: + # モジュール指定あり + if lora_name in modules_dim: + dim = modules_dim[lora_name] + alpha = modules_alpha[lora_name] + elif self.reg_dims is not None: + for reg, d in self.reg_dims.items(): + if re.search(reg, lora_name): + dim = d + alpha = self.alpha + logger.info(f"LoRA {lora_name} matched with regex {reg}, using dim: {dim}") + break + + # if modules_dim is None, we use default lora_dim. if modules_dim is not None, we use the specified dim (no default) + if dim is None and modules_dim is None: + if is_linear or is_conv2d_1x1: + dim = default_dim if default_dim is not None else self.lora_dim + alpha = self.alpha + elif self.conv_lora_dim is not None: + dim = self.conv_lora_dim + alpha = self.conv_alpha + + if dim is None or dim == 0: + # skipした情報を出力 + if is_linear or is_conv2d_1x1 or (self.conv_lora_dim is not None): + skipped.append(lora_name) + continue + + # qkv split + split_dims = None + if is_dit and split_qkv: + if "double" in lora_name and "qkv" in lora_name: + split_dims = self.get_qkv_mlp_split_dims()[:3] # qkv only + elif "single" in lora_name and "linear1" in lora_name: + split_dims = self.get_qkv_mlp_split_dims() # qkv + mlp + + lora = module_class( + lora_name, + child_module, + self.multiplier, + dim, + alpha, + dropout=dropout, + rank_dropout=rank_dropout, + module_dropout=module_dropout, + split_dims=split_dims, + ggpo_beta=ggpo_beta, + ggpo_sigma=ggpo_sigma, + ) + loras.append(lora) + + if target_replace_modules is None: + break # all modules are searched + return loras, skipped + + # create LoRA for U-Net + target_replace_modules = ( + HunyuanImageLoRANetwork.TARGET_REPLACE_MODULE_DOUBLE + HunyuanImageLoRANetwork.TARGET_REPLACE_MODULE_SINGLE + ) + + self.unet_loras: List[Union[lora_flux.LoRAModule, lora_flux.LoRAInfModule]] + self.unet_loras, skipped_un = create_modules(True, None, unet, target_replace_modules) + self.text_encoder_loras = [] + + logger.info(f"create LoRA for HunyuanImage-2.1: {len(self.unet_loras)} modules.") + if verbose: + for lora in self.unet_loras: + logger.info(f"\t{lora.lora_name:50} {lora.lora_dim}, {lora.alpha}") + + skipped = skipped_un + if verbose and len(skipped) > 0: + logger.warning( + f"because dim (rank) is 0, {len(skipped)} LoRA modules are skipped / dim (rank)が0の為、次の{len(skipped)}個のLoRAモジュールはスキップされます:" + ) + for name in skipped: + logger.info(f"\t{name}") + + # assertion + names = set() + for lora in self.text_encoder_loras + self.unet_loras: + assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}" + names.add(lora.lora_name) diff --git a/networks/lora_lumina.py b/networks/lora_lumina.py new file mode 100644 index 000000000..0929e8390 --- /dev/null +++ b/networks/lora_lumina.py @@ -0,0 +1,1038 @@ +# temporary minimum implementation of LoRA +# Lumina 2 does not have Conv2d, so ignore +# TODO commonize with the original implementation + +# LoRA network module +# reference: +# https://github.com/microsoft/LoRA/blob/main/loralib/layers.py +# https://github.com/cloneofsimo/lora/blob/master/lora_diffusion/lora.py + +import math +import os +from typing import Dict, List, Optional, Tuple, Type, Union +from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL +from transformers import CLIPTextModel +import torch +from torch import Tensor, nn +from library.utils import setup_logging + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + + +class LoRAModule(torch.nn.Module): + """ + replaces forward method of the original Linear, instead of replacing the original Linear module. + """ + + def __init__( + self, + lora_name: str, + org_module: nn.Module, + multiplier: float =1.0, + lora_dim: int = 4, + alpha: Optional[float | int | Tensor] = 1, + dropout: Optional[float] = None, + rank_dropout: Optional[float] = None, + module_dropout: Optional[float] = None, + split_dims: Optional[List[int]] = None, + ): + """ + if alpha == 0 or None, alpha is rank (no scaling). + + split_dims is used to mimic the split qkv of lumina as same as Diffusers + """ + super().__init__() + self.lora_name = lora_name + + if org_module.__class__.__name__ == "Conv2d": + in_dim = org_module.in_channels + out_dim = org_module.out_channels + else: + in_dim = org_module.in_features + out_dim = org_module.out_features + + assert isinstance(in_dim, int) + assert isinstance(out_dim, int) + + self.lora_dim = lora_dim + self.split_dims = split_dims + + if split_dims is None: + if org_module.__class__.__name__ == "Conv2d": + kernel_size = org_module.kernel_size + stride = org_module.stride + padding = org_module.padding + self.lora_down = nn.Conv2d(in_dim, self.lora_dim, kernel_size, stride, padding, bias=False) + self.lora_up = nn.Conv2d(self.lora_dim, out_dim, (1, 1), (1, 1), bias=False) + else: + self.lora_down = nn.Linear(in_dim, self.lora_dim, bias=False) + self.lora_up = nn.Linear(self.lora_dim, out_dim, bias=False) + + nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5)) + nn.init.zeros_(self.lora_up.weight) + else: + # conv2d not supported + assert sum(split_dims) == out_dim, "sum of split_dims must be equal to out_dim" + assert org_module.__class__.__name__ == "Linear", "split_dims is only supported for Linear" + # print(f"split_dims: {split_dims}") + self.lora_down = nn.ModuleList( + [nn.Linear(in_dim, self.lora_dim, bias=False) for _ in range(len(split_dims))] + ) + self.lora_up = nn.ModuleList([torch.nn.Linear(self.lora_dim, split_dim, bias=False) for split_dim in split_dims]) + + for lora_down in self.lora_down: + nn.init.kaiming_uniform_(lora_down.weight, a=math.sqrt(5)) + for lora_up in self.lora_up: + nn.init.zeros_(lora_up.weight) + + if isinstance(alpha, Tensor): + alpha = alpha.detach().cpu().float().item() # without casting, bf16 causes error + alpha = self.lora_dim if alpha is None or alpha == 0 else alpha + self.scale = alpha / self.lora_dim + self.register_buffer("alpha", torch.tensor(alpha)) # 定数として扱える + + # same as microsoft's + self.multiplier = multiplier + self.org_module = org_module # remove in applying + self.dropout = dropout + self.rank_dropout = rank_dropout + self.module_dropout = module_dropout + + def apply_to(self): + self.org_forward = self.org_module.forward + self.org_module.forward = self.forward + del self.org_module + + def forward(self, x): + org_forwarded = self.org_forward(x) + + # module dropout + if self.module_dropout is not None and self.training: + if torch.rand(1) < self.module_dropout: + return org_forwarded + + if self.split_dims is None: + lx = self.lora_down(x) + + # normal dropout + if self.dropout is not None and self.training: + lx = torch.nn.functional.dropout(lx, p=self.dropout) + + # rank dropout + if self.rank_dropout is not None and self.training: + mask = torch.rand((lx.size(0), self.lora_dim), device=lx.device) > self.rank_dropout + if len(lx.size()) == 3: + mask = mask.unsqueeze(1) # for Text Encoder + elif len(lx.size()) == 4: + mask = mask.unsqueeze(-1).unsqueeze(-1) # for Conv2d + lx = lx * mask + + # scaling for rank dropout: treat as if the rank is changed + # maskから計算することも考えられるが、augmentation的な効果を期待してrank_dropoutを用いる + scale = self.scale * (1.0 / (1.0 - self.rank_dropout)) # redundant for readability + else: + scale = self.scale + + lx = self.lora_up(lx) + + return org_forwarded + lx * self.multiplier * scale + else: + lxs = [lora_down(x) for lora_down in self.lora_down] + + # normal dropout + if self.dropout is not None and self.training: + lxs = [torch.nn.functional.dropout(lx, p=self.dropout) for lx in lxs] + + # rank dropout + if self.rank_dropout is not None and self.training: + masks = [torch.rand((lx.size(0), self.lora_dim), device=lx.device) > self.rank_dropout for lx in lxs] + for i in range(len(lxs)): + if len(lxs[i].size()) == 3: + masks[i] = masks[i].unsqueeze(1) + elif len(lxs[i].size()) == 4: + masks[i] = masks[i].unsqueeze(-1).unsqueeze(-1) + lxs[i] = lxs[i] * masks[i] + + # scaling for rank dropout: treat as if the rank is changed + scale = self.scale * (1.0 / (1.0 - self.rank_dropout)) # redundant for readability + else: + scale = self.scale + + lxs = [lora_up(lx) for lora_up, lx in zip(self.lora_up, lxs)] + + return org_forwarded + torch.cat(lxs, dim=-1) * self.multiplier * scale + + +class LoRAInfModule(LoRAModule): + def __init__( + self, + lora_name, + org_module: torch.nn.Module, + multiplier=1.0, + lora_dim=4, + alpha=1, + **kwargs, + ): + # no dropout for inference + super().__init__(lora_name, org_module, multiplier, lora_dim, alpha) + + self.org_module_ref = [org_module] # 後から参照できるように + self.enabled = True + self.network: LoRANetwork = None + + def set_network(self, network): + self.network = network + + # freezeしてマージする + def merge_to(self, sd, dtype, device): + # extract weight from org_module + org_sd = self.org_module.state_dict() + weight = org_sd["weight"] + org_dtype = weight.dtype + org_device = weight.device + weight = weight.to(torch.float) # calc in float + + if dtype is None: + dtype = org_dtype + if device is None: + device = org_device + + if self.split_dims is None: + # get up/down weight + down_weight = sd["lora_down.weight"].to(torch.float).to(device) + up_weight = sd["lora_up.weight"].to(torch.float).to(device) + + # merge weight + if len(weight.size()) == 2: + # linear + weight = weight + self.multiplier * (up_weight @ down_weight) * self.scale + elif down_weight.size()[2:4] == (1, 1): + # conv2d 1x1 + weight = ( + weight + + self.multiplier + * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3) + * self.scale + ) + else: + # conv2d 3x3 + conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3) + # logger.info(conved.size(), weight.size(), module.stride, module.padding) + weight = weight + self.multiplier * conved * self.scale + + # set weight to org_module + org_sd["weight"] = weight.to(dtype) + self.org_module.load_state_dict(org_sd) + else: + # split_dims + total_dims = sum(self.split_dims) + for i in range(len(self.split_dims)): + # get up/down weight + down_weight = sd[f"lora_down.{i}.weight"].to(torch.float).to(device) # (rank, in_dim) + up_weight = sd[f"lora_up.{i}.weight"].to(torch.float).to(device) # (split dim, rank) + + # pad up_weight -> (total_dims, rank) + padded_up_weight = torch.zeros((total_dims, up_weight.size(0)), device=device, dtype=torch.float) + padded_up_weight[sum(self.split_dims[:i]) : sum(self.split_dims[: i + 1])] = up_weight + + # merge weight + weight = weight + self.multiplier * (up_weight @ down_weight) * self.scale + + # set weight to org_module + org_sd["weight"] = weight.to(dtype) + self.org_module.load_state_dict(org_sd) + + # 復元できるマージのため、このモジュールのweightを返す + def get_weight(self, multiplier=None): + if multiplier is None: + multiplier = self.multiplier + + # get up/down weight from module + up_weight = self.lora_up.weight.to(torch.float) + down_weight = self.lora_down.weight.to(torch.float) + + # pre-calculated weight + if len(down_weight.size()) == 2: + # linear + weight = self.multiplier * (up_weight @ down_weight) * self.scale + elif down_weight.size()[2:4] == (1, 1): + # conv2d 1x1 + weight = ( + self.multiplier + * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3) + * self.scale + ) + else: + # conv2d 3x3 + conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3) + weight = self.multiplier * conved * self.scale + + return weight + + def set_region(self, region): + self.region = region + self.region_mask = None + + def default_forward(self, x): + # logger.info(f"default_forward {self.lora_name} {x.size()}") + if self.split_dims is None: + lx = self.lora_down(x) + lx = self.lora_up(lx) + return self.org_forward(x) + lx * self.multiplier * self.scale + else: + lxs = [lora_down(x) for lora_down in self.lora_down] + lxs = [lora_up(lx) for lora_up, lx in zip(self.lora_up, lxs)] + return self.org_forward(x) + torch.cat(lxs, dim=-1) * self.multiplier * self.scale + + def forward(self, x): + if not self.enabled: + return self.org_forward(x) + return self.default_forward(x) + + +def create_network( + multiplier: float, + network_dim: Optional[int], + network_alpha: Optional[float], + ae: AutoencoderKL, + text_encoders: List[CLIPTextModel], + lumina, + neuron_dropout: Optional[float] = None, + **kwargs, +): + if network_dim is None: + network_dim = 4 # default + if network_alpha is None: + network_alpha = 1.0 + + # extract dim/alpha for conv2d, and block dim + conv_dim = kwargs.get("conv_dim", None) + conv_alpha = kwargs.get("conv_alpha", None) + if conv_dim is not None: + conv_dim = int(conv_dim) + if conv_alpha is None: + conv_alpha = 1.0 + else: + conv_alpha = float(conv_alpha) + + # attn dim, mlp dim for JointTransformerBlock + attn_dim = kwargs.get("attn_dim", None) # attention dimension + mlp_dim = kwargs.get("mlp_dim", None) # MLP dimension + mod_dim = kwargs.get("mod_dim", None) # modulation dimension + refiner_dim = kwargs.get("refiner_dim", None) # refiner blocks dimension + + if attn_dim is not None: + attn_dim = int(attn_dim) + if mlp_dim is not None: + mlp_dim = int(mlp_dim) + if mod_dim is not None: + mod_dim = int(mod_dim) + if refiner_dim is not None: + refiner_dim = int(refiner_dim) + + type_dims = [attn_dim, mlp_dim, mod_dim, refiner_dim] + if all([d is None for d in type_dims]): + type_dims = None + + # embedder_dims for embedders + embedder_dims = kwargs.get("embedder_dims", None) + if embedder_dims is not None: + embedder_dims = embedder_dims.strip() + if embedder_dims.startswith("[") and embedder_dims.endswith("]"): + embedder_dims = embedder_dims[1:-1] + embedder_dims = [int(d) for d in embedder_dims.split(",")] + assert len(embedder_dims) == 3, f"invalid embedder_dims: {embedder_dims}, must be 3 dimensions (x_embedder, t_embedder, cap_embedder)" + + # rank/module dropout + rank_dropout = kwargs.get("rank_dropout", None) + if rank_dropout is not None: + rank_dropout = float(rank_dropout) + module_dropout = kwargs.get("module_dropout", None) + if module_dropout is not None: + module_dropout = float(module_dropout) + + # single or double blocks + train_blocks = kwargs.get("train_blocks", None) # None (default), "all" (same as None), "transformer", "refiners", "noise_refiner", "context_refiner" + if train_blocks is not None: + assert train_blocks in ["all", "transformer", "refiners", "noise_refiner", "context_refiner"], f"invalid train_blocks: {train_blocks}" + + # split qkv + split_qkv = kwargs.get("split_qkv", False) + if split_qkv is not None: + split_qkv = True if split_qkv == "True" else False + + # verbose + verbose = kwargs.get("verbose", False) + if verbose is not None: + verbose = True if verbose == "True" else False + + # すごく引数が多いな ( ^ω^)・・・ + network = LoRANetwork( + text_encoders, + lumina, + multiplier=multiplier, + lora_dim=network_dim, + alpha=network_alpha, + dropout=neuron_dropout, + rank_dropout=rank_dropout, + module_dropout=module_dropout, + conv_lora_dim=conv_dim, + conv_alpha=conv_alpha, + train_blocks=train_blocks, + split_qkv=split_qkv, + type_dims=type_dims, + embedder_dims=embedder_dims, + verbose=verbose, + ) + + loraplus_lr_ratio = kwargs.get("loraplus_lr_ratio", None) + loraplus_unet_lr_ratio = kwargs.get("loraplus_unet_lr_ratio", None) + loraplus_text_encoder_lr_ratio = kwargs.get("loraplus_text_encoder_lr_ratio", None) + loraplus_lr_ratio = float(loraplus_lr_ratio) if loraplus_lr_ratio is not None else None + loraplus_unet_lr_ratio = float(loraplus_unet_lr_ratio) if loraplus_unet_lr_ratio is not None else None + loraplus_text_encoder_lr_ratio = float(loraplus_text_encoder_lr_ratio) if loraplus_text_encoder_lr_ratio is not None else None + if loraplus_lr_ratio is not None or loraplus_unet_lr_ratio is not None or loraplus_text_encoder_lr_ratio is not None: + network.set_loraplus_lr_ratio(loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio) + + return network + + +# Create network from weights for inference, weights are not loaded here (because can be merged) +def create_network_from_weights(multiplier, file, ae, text_encoders, lumina, weights_sd=None, for_inference=False, **kwargs): + # if unet is an instance of SdxlUNet2DConditionModel or subclass, set is_sdxl to True + if weights_sd is None: + if os.path.splitext(file)[1] == ".safetensors": + from safetensors.torch import load_file, safe_open + + weights_sd = load_file(file) + else: + weights_sd = torch.load(file, map_location="cpu") + + # get dim/alpha mapping, and train t5xxl + modules_dim = {} + modules_alpha = {} + for key, value in weights_sd.items(): + if "." not in key: + continue + + lora_name = key.split(".")[0] + if "alpha" in key: + modules_alpha[lora_name] = value + elif "lora_down" in key: + dim = value.size()[0] + modules_dim[lora_name] = dim + # logger.info(lora_name, value.size(), dim) + + # # split qkv + # double_qkv_rank = None + # single_qkv_rank = None + # rank = None + # for lora_name, dim in modules_dim.items(): + # if "double" in lora_name and "qkv" in lora_name: + # double_qkv_rank = dim + # elif "single" in lora_name and "linear1" in lora_name: + # single_qkv_rank = dim + # elif rank is None: + # rank = dim + # if double_qkv_rank is not None and single_qkv_rank is not None and rank is not None: + # break + # split_qkv = (double_qkv_rank is not None and double_qkv_rank != rank) or ( + # single_qkv_rank is not None and single_qkv_rank != rank + # ) + split_qkv = False # split_qkv is not needed to care, because state_dict is qkv combined + + module_class = LoRAInfModule if for_inference else LoRAModule + + network = LoRANetwork( + text_encoders, + lumina, + multiplier=multiplier, + modules_dim=modules_dim, + modules_alpha=modules_alpha, + module_class=module_class, + split_qkv=split_qkv, + ) + return network, weights_sd + + +class LoRANetwork(torch.nn.Module): + LUMINA_TARGET_REPLACE_MODULE = ["JointTransformerBlock", "FinalLayer"] + TEXT_ENCODER_TARGET_REPLACE_MODULE = ["Gemma2Attention", "Gemma2FlashAttention2", "Gemma2SdpaAttention", "Gemma2MLP"] + LORA_PREFIX_LUMINA = "lora_unet" + LORA_PREFIX_TEXT_ENCODER = "lora_te" # Simplified prefix since we only have one text encoder + + def __init__( + self, + text_encoders, # Now this will be a single Gemma2 model + unet, + multiplier: float = 1.0, + lora_dim: int = 4, + alpha: float = 1, + dropout: Optional[float] = None, + rank_dropout: Optional[float] = None, + module_dropout: Optional[float] = None, + conv_lora_dim: Optional[int] = None, + conv_alpha: Optional[float] = None, + module_class: Type[LoRAModule] = LoRAModule, + modules_dim: Optional[Dict[str, int]] = None, + modules_alpha: Optional[Dict[str, int]] = None, + train_blocks: Optional[str] = None, + split_qkv: bool = False, + type_dims: Optional[List[int]] = None, + embedder_dims: Optional[List[int]] = None, + train_block_indices: Optional[List[bool]] = None, + verbose: Optional[bool] = False, + ) -> None: + super().__init__() + self.multiplier = multiplier + + self.lora_dim = lora_dim + self.alpha = alpha + self.conv_lora_dim = conv_lora_dim + self.conv_alpha = conv_alpha + self.dropout = dropout + self.rank_dropout = rank_dropout + self.module_dropout = module_dropout + self.train_blocks = train_blocks if train_blocks is not None else "all" + self.split_qkv = split_qkv + + self.type_dims = type_dims + self.embedder_dims = embedder_dims + + self.train_block_indices = train_block_indices + + self.loraplus_lr_ratio = None + self.loraplus_unet_lr_ratio = None + self.loraplus_text_encoder_lr_ratio = None + + if modules_dim is not None: + logger.info(f"create LoRA network from weights") + self.embedder_dims = [0] * 5 # create embedder_dims + # verbose = True + else: + logger.info(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}") + logger.info( + f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}" + ) + # if self.conv_lora_dim is not None: + # logger.info( + # f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}" + # ) + if self.split_qkv: + logger.info(f"split qkv for LoRA") + if self.train_blocks is not None: + logger.info(f"train {self.train_blocks} blocks only") + + # create module instances + def create_modules( + is_lumina: bool, + root_module: torch.nn.Module, + target_replace_modules: Optional[List[str]], + filter: Optional[str] = None, + default_dim: Optional[int] = None, + ) -> List[LoRAModule]: + prefix = self.LORA_PREFIX_LUMINA if is_lumina else self.LORA_PREFIX_TEXT_ENCODER + + loras = [] + skipped = [] + for name, module in root_module.named_modules(): + if target_replace_modules is None or module.__class__.__name__ in target_replace_modules: + if target_replace_modules is None: # for handling embedders + module = root_module + + for child_name, child_module in module.named_modules(): + is_linear = child_module.__class__.__name__ == "Linear" + + lora_name = prefix + "." + (name + "." if name else "") + child_name + lora_name = lora_name.replace(".", "_") + + # Only Linear is supported + if not is_linear: + skipped.append(lora_name) + continue + + if filter is not None and filter not in lora_name: + continue + + dim = default_dim if default_dim is not None else self.lora_dim + alpha = self.alpha + + # Set dim/alpha to modules dim/alpha + if modules_dim is not None and modules_alpha is not None: + # network from weights + if lora_name in modules_dim: + dim = modules_dim[lora_name] + alpha = modules_alpha[lora_name] + else: + dim = 0 # skip if not found + + else: + # Set dims to type_dims + if is_lumina and type_dims is not None: + identifier = [ + ("attention",), # attention layers + ("mlp",), # MLP layers + ("modulation",), # modulation layers + ("refiner",), # refiner blocks + ] + for i, d in enumerate(type_dims): + if d is not None and all([id in lora_name for id in identifier[i]]): + dim = d # may be 0 for skip + break + + # Drop blocks if we are only training some blocks + if ( + is_lumina + and dim + and ( + self.train_block_indices is not None + ) + and ("layer" in lora_name) + ): + # "lora_unet_layers_0_..." or "lora_unet_cap_refiner_0_..." or or "lora_unet_noise_refiner_0_..." + block_index = int(lora_name.split("_")[3]) # bit dirty + if ( + "layer" in lora_name + and self.train_block_indices is not None + and not self.train_block_indices[block_index] + ): + dim = 0 + + + if dim is None or dim == 0: + # skipした情報を出力 + skipped.append(lora_name) + continue + + lora = module_class( + lora_name, + child_module, + self.multiplier, + dim, + alpha, + dropout=dropout, + rank_dropout=rank_dropout, + module_dropout=module_dropout, + ) + loras.append(lora) + + if target_replace_modules is None: + break # all modules are searched + return loras, skipped + + # create LoRA for text encoder (Gemma2) + self.text_encoder_loras: List[Union[LoRAModule, LoRAInfModule]] = [] + skipped_te = [] + + logger.info(f"create LoRA for Gemma2 Text Encoder:") + text_encoder_loras, skipped = create_modules(False, text_encoders[0], LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE) + logger.info(f"create LoRA for Gemma2 Text Encoder: {len(text_encoder_loras)} modules.") + self.text_encoder_loras.extend(text_encoder_loras) + skipped_te += skipped + + # create LoRA for U-Net + if self.train_blocks == "all": + target_replace_modules = LoRANetwork.LUMINA_TARGET_REPLACE_MODULE + # TODO: limit different blocks + elif self.train_blocks == "transformer": + target_replace_modules = LoRANetwork.LUMINA_TARGET_REPLACE_MODULE + elif self.train_blocks == "refiners": + target_replace_modules = LoRANetwork.LUMINA_TARGET_REPLACE_MODULE + elif self.train_blocks == "noise_refiner": + target_replace_modules = LoRANetwork.LUMINA_TARGET_REPLACE_MODULE + elif self.train_blocks == "cap_refiner": + target_replace_modules = LoRANetwork.LUMINA_TARGET_REPLACE_MODULE + + self.unet_loras: List[Union[LoRAModule, LoRAInfModule]] + self.unet_loras, skipped_un = create_modules(True, unet, target_replace_modules) + + # Handle embedders + if self.embedder_dims: + for filter, embedder_dim in zip(["x_embedder", "t_embedder", "cap_embedder"], self.embedder_dims): + loras, _ = create_modules(True, unet, None, filter=filter, default_dim=embedder_dim) + self.unet_loras.extend(loras) + + logger.info(f"create LoRA for Lumina blocks: {len(self.unet_loras)} modules.") + if verbose: + for lora in self.unet_loras: + logger.info(f"\t{lora.lora_name:50} {lora.lora_dim}, {lora.alpha}") + + skipped = skipped_te + skipped_un + if verbose and len(skipped) > 0: + logger.warning( + f"because dim (rank) is 0, {len(skipped)} LoRA modules are skipped / dim (rank)が0の為、次の{len(skipped)}個のLoRAモジュールはスキップされます:" + ) + for name in skipped: + logger.info(f"\t{name}") + + # assertion + names = set() + for lora in self.text_encoder_loras + self.unet_loras: + assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}" + names.add(lora.lora_name) + + def set_multiplier(self, multiplier): + self.multiplier = multiplier + for lora in self.text_encoder_loras + self.unet_loras: + lora.multiplier = self.multiplier + + def set_enabled(self, is_enabled): + for lora in self.text_encoder_loras + self.unet_loras: + lora.enabled = is_enabled + + def load_weights(self, file): + if os.path.splitext(file)[1] == ".safetensors": + from safetensors.torch import load_file + + weights_sd = load_file(file) + else: + weights_sd = torch.load(file, map_location="cpu") + + info = self.load_state_dict(weights_sd, False) + return info + + def load_state_dict(self, state_dict, strict=True): + # override to convert original weight to split qkv + if not self.split_qkv: + return super().load_state_dict(state_dict, strict) + + # # split qkv + # for key in list(state_dict.keys()): + # if "double" in key and "qkv" in key: + # split_dims = [3072] * 3 + # elif "single" in key and "linear1" in key: + # split_dims = [3072] * 3 + [12288] + # else: + # continue + + # weight = state_dict[key] + # lora_name = key.split(".")[0] + + # if key not in state_dict: + # continue # already merged + + # # (rank, in_dim) * 3 + # down_weights = [state_dict.pop(f"{lora_name}.lora_down.{i}.weight") for i in range(len(split_dims))] + # # (split dim, rank) * 3 + # up_weights = [state_dict.pop(f"{lora_name}.lora_up.{i}.weight") for i in range(len(split_dims))] + + # alpha = state_dict.pop(f"{lora_name}.alpha") + + # # merge down weight + # down_weight = torch.cat(down_weights, dim=0) # (rank, split_dim) * 3 -> (rank*3, sum of split_dim) + + # # merge up weight (sum of split_dim, rank*3) + # rank = up_weights[0].size(1) + # up_weight = torch.zeros((sum(split_dims), down_weight.size(0)), device=down_weight.device, dtype=down_weight.dtype) + # i = 0 + # for j in range(len(split_dims)): + # up_weight[i : i + split_dims[j], j * rank : (j + 1) * rank] = up_weights[j] + # i += split_dims[j] + + # state_dict[f"{lora_name}.lora_down.weight"] = down_weight + # state_dict[f"{lora_name}.lora_up.weight"] = up_weight + # state_dict[f"{lora_name}.alpha"] = alpha + + # # print( + # # f"merged {lora_name}: {lora_name}, {[w.shape for w in down_weights]}, {[w.shape for w in up_weights]} to {down_weight.shape}, {up_weight.shape}" + # # ) + # print(f"new key: {lora_name}.lora_down.weight, {lora_name}.lora_up.weight, {lora_name}.alpha") + + return super().load_state_dict(state_dict, strict) + + def state_dict(self, destination=None, prefix="", keep_vars=False): + if not self.split_qkv: + return super().state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars) + + # merge qkv + state_dict = super().state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars) + new_state_dict = {} + for key in list(state_dict.keys()): + if "double" in key and "qkv" in key: + split_dims = [3072] * 3 + elif "single" in key and "linear1" in key: + split_dims = [3072] * 3 + [12288] + else: + new_state_dict[key] = state_dict[key] + continue + + if key not in state_dict: + continue # already merged + + lora_name = key.split(".")[0] + + # (rank, in_dim) * 3 + down_weights = [state_dict.pop(f"{lora_name}.lora_down.{i}.weight") for i in range(len(split_dims))] + # (split dim, rank) * 3 + up_weights = [state_dict.pop(f"{lora_name}.lora_up.{i}.weight") for i in range(len(split_dims))] + + alpha = state_dict.pop(f"{lora_name}.alpha") + + # merge down weight + down_weight = torch.cat(down_weights, dim=0) # (rank, split_dim) * 3 -> (rank*3, sum of split_dim) + + # merge up weight (sum of split_dim, rank*3) + rank = up_weights[0].size(1) + up_weight = torch.zeros((sum(split_dims), down_weight.size(0)), device=down_weight.device, dtype=down_weight.dtype) + i = 0 + for j in range(len(split_dims)): + up_weight[i : i + split_dims[j], j * rank : (j + 1) * rank] = up_weights[j] + i += split_dims[j] + + new_state_dict[f"{lora_name}.lora_down.weight"] = down_weight + new_state_dict[f"{lora_name}.lora_up.weight"] = up_weight + new_state_dict[f"{lora_name}.alpha"] = alpha + + # print( + # f"merged {lora_name}: {lora_name}, {[w.shape for w in down_weights]}, {[w.shape for w in up_weights]} to {down_weight.shape}, {up_weight.shape}" + # ) + print(f"new key: {lora_name}.lora_down.weight, {lora_name}.lora_up.weight, {lora_name}.alpha") + + return new_state_dict + + def apply_to(self, text_encoders, flux, apply_text_encoder=True, apply_unet=True): + if apply_text_encoder: + logger.info(f"enable LoRA for text encoder: {len(self.text_encoder_loras)} modules") + else: + self.text_encoder_loras = [] + + if apply_unet: + logger.info(f"enable LoRA for U-Net: {len(self.unet_loras)} modules") + else: + self.unet_loras = [] + + for lora in self.text_encoder_loras + self.unet_loras: + lora.apply_to() + self.add_module(lora.lora_name, lora) + + # マージできるかどうかを返す + def is_mergeable(self): + return True + + # TODO refactor to common function with apply_to + def merge_to(self, text_encoders, flux, weights_sd, dtype=None, device=None): + apply_text_encoder = apply_unet = False + for key in weights_sd.keys(): + if key.startswith(LoRANetwork.LORA_PREFIX_TEXT_ENCODER): + apply_text_encoder = True + elif key.startswith(LoRANetwork.LORA_PREFIX_LUMINA): + apply_unet = True + + if apply_text_encoder: + logger.info("enable LoRA for text encoder") + else: + self.text_encoder_loras = [] + + if apply_unet: + logger.info("enable LoRA for U-Net") + else: + self.unet_loras = [] + + for lora in self.text_encoder_loras + self.unet_loras: + sd_for_lora = {} + for key in weights_sd.keys(): + if key.startswith(lora.lora_name): + sd_for_lora[key[len(lora.lora_name) + 1 :]] = weights_sd[key] + lora.merge_to(sd_for_lora, dtype, device) + + logger.info(f"weights are merged") + + def set_loraplus_lr_ratio(self, loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio): + self.loraplus_lr_ratio = loraplus_lr_ratio + self.loraplus_unet_lr_ratio = loraplus_unet_lr_ratio + self.loraplus_text_encoder_lr_ratio = loraplus_text_encoder_lr_ratio + + logger.info(f"LoRA+ UNet LR Ratio: {self.loraplus_unet_lr_ratio or self.loraplus_lr_ratio}") + logger.info(f"LoRA+ Text Encoder LR Ratio: {self.loraplus_text_encoder_lr_ratio or self.loraplus_lr_ratio}") + + def prepare_optimizer_params_with_multiple_te_lrs(self, text_encoder_lr, unet_lr, default_lr): + # make sure text_encoder_lr as list of two elements + # if float, use the same value for both text encoders + if text_encoder_lr is None or (isinstance(text_encoder_lr, list) and len(text_encoder_lr) == 0): + text_encoder_lr = [default_lr, default_lr] + elif isinstance(text_encoder_lr, float) or isinstance(text_encoder_lr, int): + text_encoder_lr = [float(text_encoder_lr), float(text_encoder_lr)] + elif len(text_encoder_lr) == 1: + text_encoder_lr = [text_encoder_lr[0], text_encoder_lr[0]] + + self.requires_grad_(True) + + all_params = [] + lr_descriptions = [] + + def assemble_params(loras, lr, loraplus_ratio): + param_groups = {"lora": {}, "plus": {}} + for lora in loras: + for name, param in lora.named_parameters(): + if loraplus_ratio is not None and "lora_up" in name: + param_groups["plus"][f"{lora.lora_name}.{name}"] = param + else: + param_groups["lora"][f"{lora.lora_name}.{name}"] = param + + params = [] + descriptions = [] + for key in param_groups.keys(): + param_data = {"params": param_groups[key].values()} + + if len(param_data["params"]) == 0: + continue + + if lr is not None: + if key == "plus": + param_data["lr"] = lr * loraplus_ratio + else: + param_data["lr"] = lr + + if param_data.get("lr", None) == 0 or param_data.get("lr", None) is None: + logger.info("NO LR skipping!") + continue + + params.append(param_data) + descriptions.append("plus" if key == "plus" else "") + + return params, descriptions + + if self.text_encoder_loras: + loraplus_lr_ratio = self.loraplus_text_encoder_lr_ratio or self.loraplus_lr_ratio + + # split text encoder loras for te1 and te3 + te_loras = [lora for lora in self.text_encoder_loras] + if len(te_loras) > 0: + logger.info(f"Text Encoder: {len(te_loras)} modules, LR {text_encoder_lr[0]}") + params, descriptions = assemble_params(te_loras, text_encoder_lr[0], loraplus_lr_ratio) + all_params.extend(params) + lr_descriptions.extend(["textencoder " + (" " + d if d else "") for d in descriptions]) + + if self.unet_loras: + params, descriptions = assemble_params( + self.unet_loras, + unet_lr if unet_lr is not None else default_lr, + self.loraplus_unet_lr_ratio or self.loraplus_lr_ratio, + ) + all_params.extend(params) + lr_descriptions.extend(["unet" + (" " + d if d else "") for d in descriptions]) + + return all_params, lr_descriptions + + def enable_gradient_checkpointing(self): + # not supported + pass + + def prepare_grad_etc(self, text_encoder, unet): + self.requires_grad_(True) + + def on_epoch_start(self, text_encoder, unet): + self.train() + + def get_trainable_params(self): + return self.parameters() + + def save_weights(self, file, dtype, metadata): + if metadata is not None and len(metadata) == 0: + metadata = None + + state_dict = self.state_dict() + + if dtype is not None: + for key in list(state_dict.keys()): + v = state_dict[key] + v = v.detach().clone().to("cpu").to(dtype) + state_dict[key] = v + + if os.path.splitext(file)[1] == ".safetensors": + from safetensors.torch import save_file + from library import train_util + + # Precalculate model hashes to save time on indexing + if metadata is None: + metadata = {} + model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata) + metadata["sshs_model_hash"] = model_hash + metadata["sshs_legacy_hash"] = legacy_hash + + save_file(state_dict, file, metadata) + else: + torch.save(state_dict, file) + + def backup_weights(self): + # 重みのバックアップを行う + loras: List[LoRAInfModule] = self.text_encoder_loras + self.unet_loras + for lora in loras: + org_module = lora.org_module_ref[0] + if not hasattr(org_module, "_lora_org_weight"): + sd = org_module.state_dict() + org_module._lora_org_weight = sd["weight"].detach().clone() + org_module._lora_restored = True + + def restore_weights(self): + # 重みのリストアを行う + loras: List[LoRAInfModule] = self.text_encoder_loras + self.unet_loras + for lora in loras: + org_module = lora.org_module_ref[0] + if not org_module._lora_restored: + sd = org_module.state_dict() + sd["weight"] = org_module._lora_org_weight + org_module.load_state_dict(sd) + org_module._lora_restored = True + + def pre_calculation(self): + # 事前計算を行う + loras: List[LoRAInfModule] = self.text_encoder_loras + self.unet_loras + for lora in loras: + org_module = lora.org_module_ref[0] + sd = org_module.state_dict() + + org_weight = sd["weight"] + lora_weight = lora.get_weight().to(org_weight.device, dtype=org_weight.dtype) + sd["weight"] = org_weight + lora_weight + assert sd["weight"].shape == org_weight.shape + org_module.load_state_dict(sd) + + org_module._lora_restored = False + lora.enabled = False + + def apply_max_norm_regularization(self, max_norm_value, device): + downkeys = [] + upkeys = [] + alphakeys = [] + norms = [] + keys_scaled = 0 + + state_dict = self.state_dict() + for key in state_dict.keys(): + if "lora_down" in key and "weight" in key: + downkeys.append(key) + upkeys.append(key.replace("lora_down", "lora_up")) + alphakeys.append(key.replace("lora_down.weight", "alpha")) + + for i in range(len(downkeys)): + down = state_dict[downkeys[i]].to(device) + up = state_dict[upkeys[i]].to(device) + alpha = state_dict[alphakeys[i]].to(device) + dim = down.shape[0] + scale = alpha / dim + + if up.shape[2:] == (1, 1) and down.shape[2:] == (1, 1): + updown = (up.squeeze(2).squeeze(2) @ down.squeeze(2).squeeze(2)).unsqueeze(2).unsqueeze(3) + elif up.shape[2:] == (3, 3) or down.shape[2:] == (3, 3): + updown = torch.nn.functional.conv2d(down.permute(1, 0, 2, 3), up).permute(1, 0, 2, 3) + else: + updown = up @ down + + updown *= scale + + norm = updown.norm().clamp(min=max_norm_value / 2) + desired = torch.clamp(norm, max=max_norm_value) + ratio = desired.cpu() / norm.cpu() + sqrt_ratio = ratio**0.5 + if ratio != 1: + keys_scaled += 1 + state_dict[upkeys[i]] *= sqrt_ratio + state_dict[downkeys[i]] *= sqrt_ratio + scalednorm = updown.norm() * ratio + norms.append(scalednorm.item()) + + return keys_scaled, sum(norms) / len(norms), max(norms) diff --git a/networks/lora_sd3.py b/networks/lora_sd3.py new file mode 100644 index 000000000..ce6d1a16f --- /dev/null +++ b/networks/lora_sd3.py @@ -0,0 +1,839 @@ +# temporary minimum implementation of LoRA +# SD3 doesn't have Conv2d, so we ignore it +# TODO commonize with the original/SD3/FLUX implementation + +# LoRA network module +# reference: +# https://github.com/microsoft/LoRA/blob/main/loralib/layers.py +# https://github.com/cloneofsimo/lora/blob/master/lora_diffusion/lora.py + +import math +import os +from typing import Dict, List, Optional, Tuple, Type, Union +from transformers import CLIPTextModelWithProjection, T5EncoderModel +import numpy as np +import torch +from library.utils import setup_logging + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + +from networks.lora_flux import LoRAModule, LoRAInfModule +from library import sd3_models + + +def create_network( + multiplier: float, + network_dim: Optional[int], + network_alpha: Optional[float], + vae: sd3_models.SDVAE, + text_encoders: List[Union[CLIPTextModelWithProjection, T5EncoderModel]], + mmdit, + neuron_dropout: Optional[float] = None, + **kwargs, +): + if network_dim is None: + network_dim = 4 # default + if network_alpha is None: + network_alpha = 1.0 + + # extract dim/alpha for conv2d, and block dim + conv_dim = kwargs.get("conv_dim", None) + conv_alpha = kwargs.get("conv_alpha", None) + if conv_dim is not None: + conv_dim = int(conv_dim) + if conv_alpha is None: + conv_alpha = 1.0 + else: + conv_alpha = float(conv_alpha) + + # attn dim, mlp dim: only for DoubleStreamBlock. SingleStreamBlock is not supported because of combined qkv + context_attn_dim = kwargs.get("context_attn_dim", None) + context_mlp_dim = kwargs.get("context_mlp_dim", None) + context_mod_dim = kwargs.get("context_mod_dim", None) + x_attn_dim = kwargs.get("x_attn_dim", None) + x_mlp_dim = kwargs.get("x_mlp_dim", None) + x_mod_dim = kwargs.get("x_mod_dim", None) + if context_attn_dim is not None: + context_attn_dim = int(context_attn_dim) + if context_mlp_dim is not None: + context_mlp_dim = int(context_mlp_dim) + if context_mod_dim is not None: + context_mod_dim = int(context_mod_dim) + if x_attn_dim is not None: + x_attn_dim = int(x_attn_dim) + if x_mlp_dim is not None: + x_mlp_dim = int(x_mlp_dim) + if x_mod_dim is not None: + x_mod_dim = int(x_mod_dim) + type_dims = [context_attn_dim, context_mlp_dim, context_mod_dim, x_attn_dim, x_mlp_dim, x_mod_dim] + if all([d is None for d in type_dims]): + type_dims = None + + # emb_dims [context_embedder, t_embedder, x_embedder, y_embedder, final_mod, final_linear] + emb_dims = kwargs.get("emb_dims", None) + if emb_dims is not None: + emb_dims = emb_dims.strip() + if emb_dims.startswith("[") and emb_dims.endswith("]"): + emb_dims = emb_dims[1:-1] + emb_dims = [int(d) for d in emb_dims.split(",")] # is it better to use ast.literal_eval? + assert len(emb_dims) == 6, f"invalid emb_dims: {emb_dims}, must be 6 dimensions (context, t, x, y, final_mod, final_linear)" + + # double/single train blocks + def parse_block_selection(selection: str, total_blocks: int) -> List[bool]: + """ + Parse a block selection string and return a list of booleans. + + Args: + selection (str): A string specifying which blocks to select. + total_blocks (int): The total number of blocks available. + + Returns: + List[bool]: A list of booleans indicating which blocks are selected. + """ + if selection == "all": + return [True] * total_blocks + if selection == "none" or selection == "": + return [False] * total_blocks + + selected = [False] * total_blocks + ranges = selection.split(",") + + for r in ranges: + if "-" in r: + start, end = map(str.strip, r.split("-")) + start = int(start) + end = int(end) + assert 0 <= start < total_blocks, f"invalid start index: {start}" + assert 0 <= end < total_blocks, f"invalid end index: {end}" + assert start <= end, f"invalid range: {start}-{end}" + for i in range(start, end + 1): + selected[i] = True + else: + index = int(r) + assert 0 <= index < total_blocks, f"invalid index: {index}" + selected[index] = True + + return selected + + train_block_indices = kwargs.get("train_block_indices", None) + if train_block_indices is not None: + train_block_indices = parse_block_selection(train_block_indices, 999) # 999 is a dummy number + + # rank/module dropout + rank_dropout = kwargs.get("rank_dropout", None) + if rank_dropout is not None: + rank_dropout = float(rank_dropout) + module_dropout = kwargs.get("module_dropout", None) + if module_dropout is not None: + module_dropout = float(module_dropout) + + # split qkv + split_qkv = kwargs.get("split_qkv", False) + if split_qkv is not None: + split_qkv = True if split_qkv == "True" else False + + # train T5XXL + train_t5xxl = kwargs.get("train_t5xxl", False) + if train_t5xxl is not None: + train_t5xxl = True if train_t5xxl == "True" else False + + # verbose + verbose = kwargs.get("verbose", False) + if verbose is not None: + verbose = True if verbose == "True" else False + + # すごく引数が多いな ( ^ω^)・・・ + network = LoRANetwork( + text_encoders, + mmdit, + multiplier=multiplier, + lora_dim=network_dim, + alpha=network_alpha, + dropout=neuron_dropout, + rank_dropout=rank_dropout, + module_dropout=module_dropout, + conv_lora_dim=conv_dim, + conv_alpha=conv_alpha, + split_qkv=split_qkv, + train_t5xxl=train_t5xxl, + type_dims=type_dims, + emb_dims=emb_dims, + train_block_indices=train_block_indices, + verbose=verbose, + ) + + loraplus_lr_ratio = kwargs.get("loraplus_lr_ratio", None) + loraplus_unet_lr_ratio = kwargs.get("loraplus_unet_lr_ratio", None) + loraplus_text_encoder_lr_ratio = kwargs.get("loraplus_text_encoder_lr_ratio", None) + loraplus_lr_ratio = float(loraplus_lr_ratio) if loraplus_lr_ratio is not None else None + loraplus_unet_lr_ratio = float(loraplus_unet_lr_ratio) if loraplus_unet_lr_ratio is not None else None + loraplus_text_encoder_lr_ratio = float(loraplus_text_encoder_lr_ratio) if loraplus_text_encoder_lr_ratio is not None else None + if loraplus_lr_ratio is not None or loraplus_unet_lr_ratio is not None or loraplus_text_encoder_lr_ratio is not None: + network.set_loraplus_lr_ratio(loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio) + + return network + + +# Create network from weights for inference, weights are not loaded here (because can be merged) +def create_network_from_weights(multiplier, file, ae, text_encoders, mmdit, weights_sd=None, for_inference=False, **kwargs): + # if unet is an instance of SdxlUNet2DConditionModel or subclass, set is_sdxl to True + if weights_sd is None: + if os.path.splitext(file)[1] == ".safetensors": + from safetensors.torch import load_file, safe_open + + weights_sd = load_file(file) + else: + weights_sd = torch.load(file, map_location="cpu") + + # get dim/alpha mapping, and train t5xxl + modules_dim = {} + modules_alpha = {} + train_t5xxl = None + for key, value in weights_sd.items(): + if "." not in key: + continue + + lora_name = key.split(".")[0] + if "alpha" in key: + modules_alpha[lora_name] = value + elif "lora_down" in key: + dim = value.size()[0] + modules_dim[lora_name] = dim + # logger.info(lora_name, value.size(), dim) + + if train_t5xxl is None or train_t5xxl is False: + train_t5xxl = "lora_te3" in lora_name + + if train_t5xxl is None: + train_t5xxl = False + + split_qkv = False # split_qkv is not needed to care, because state_dict is qkv combined + + module_class = LoRAInfModule if for_inference else LoRAModule + + network = LoRANetwork( + text_encoders, + mmdit, + multiplier=multiplier, + modules_dim=modules_dim, + modules_alpha=modules_alpha, + module_class=module_class, + split_qkv=split_qkv, + train_t5xxl=train_t5xxl, + ) + return network, weights_sd + + +class LoRANetwork(torch.nn.Module): + SD3_TARGET_REPLACE_MODULE = ["SingleDiTBlock"] + TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPSdpaAttention", "CLIPMLP", "T5Attention", "T5DenseGatedActDense"] + LORA_PREFIX_SD3 = "lora_unet" # make ComfyUI compatible + LORA_PREFIX_TEXT_ENCODER_CLIP_L = "lora_te1" + LORA_PREFIX_TEXT_ENCODER_CLIP_G = "lora_te2" + LORA_PREFIX_TEXT_ENCODER_T5 = "lora_te3" # make ComfyUI compatible + + def __init__( + self, + text_encoders: List[Union[CLIPTextModelWithProjection, T5EncoderModel]], + unet: sd3_models.MMDiT, + multiplier: float = 1.0, + lora_dim: int = 4, + alpha: float = 1, + dropout: Optional[float] = None, + rank_dropout: Optional[float] = None, + module_dropout: Optional[float] = None, + conv_lora_dim: Optional[int] = None, + conv_alpha: Optional[float] = None, + module_class: Type[object] = LoRAModule, + modules_dim: Optional[Dict[str, int]] = None, + modules_alpha: Optional[Dict[str, int]] = None, + split_qkv: bool = False, + train_t5xxl: bool = False, + type_dims: Optional[List[int]] = None, + emb_dims: Optional[List[int]] = None, + train_block_indices: Optional[List[bool]] = None, + verbose: Optional[bool] = False, + ) -> None: + super().__init__() + self.multiplier = multiplier + + self.lora_dim = lora_dim + self.alpha = alpha + self.conv_lora_dim = conv_lora_dim + self.conv_alpha = conv_alpha + self.dropout = dropout + self.rank_dropout = rank_dropout + self.module_dropout = module_dropout + self.split_qkv = split_qkv + self.train_t5xxl = train_t5xxl + + self.type_dims = type_dims + self.emb_dims = emb_dims + self.train_block_indices = train_block_indices + + self.loraplus_lr_ratio = None + self.loraplus_unet_lr_ratio = None + self.loraplus_text_encoder_lr_ratio = None + + if modules_dim is not None: + logger.info(f"create LoRA network from weights") + self.emb_dims = [0] * 6 # create emb_dims + # verbose = True + else: + logger.info(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}") + logger.info( + f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}" + ) + # if self.conv_lora_dim is not None: + # logger.info( + # f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}" + # ) + + qkv_dim = 0 + if self.split_qkv: + logger.info(f"split qkv for LoRA") + qkv_dim = unet.joint_blocks[0].context_block.attn.qkv.weight.size(0) + if train_t5xxl: + logger.info(f"train T5XXL as well") + + # create module instances + def create_modules( + is_mmdit: bool, + text_encoder_idx: Optional[int], + root_module: torch.nn.Module, + target_replace_modules: List[str], + filter: Optional[str] = None, + default_dim: Optional[int] = None, + include_conv2d_if_filter: bool = False, + ) -> List[LoRAModule]: + prefix = ( + self.LORA_PREFIX_SD3 + if is_mmdit + else [self.LORA_PREFIX_TEXT_ENCODER_CLIP_L, self.LORA_PREFIX_TEXT_ENCODER_CLIP_G, self.LORA_PREFIX_TEXT_ENCODER_T5][ + text_encoder_idx + ] + ) + + loras = [] + skipped = [] + for name, module in root_module.named_modules(): + if target_replace_modules is None or module.__class__.__name__ in target_replace_modules: + if target_replace_modules is None: # dirty hack for all modules + module = root_module # search all modules + + for child_name, child_module in module.named_modules(): + is_linear = child_module.__class__.__name__ == "Linear" + is_conv2d = child_module.__class__.__name__ == "Conv2d" + is_conv2d_1x1 = is_conv2d and child_module.kernel_size == (1, 1) + + if is_linear or is_conv2d: + lora_name = prefix + "." + (name + "." if name else "") + child_name + lora_name = lora_name.replace(".", "_") + + force_incl_conv2d = False + if filter is not None: + if not filter in lora_name: + continue + force_incl_conv2d = include_conv2d_if_filter + + dim = None + alpha = None + + if modules_dim is not None: + # モジュール指定あり + if lora_name in modules_dim: + dim = modules_dim[lora_name] + alpha = modules_alpha[lora_name] + else: + # 通常、すべて対象とする + if is_linear or is_conv2d_1x1: + dim = default_dim if default_dim is not None else self.lora_dim + alpha = self.alpha + + if is_mmdit and type_dims is not None: + # type_dims = [context_attn_dim, context_mlp_dim, context_mod_dim, x_attn_dim, x_mlp_dim, x_mod_dim] + identifier = [ + ("context_block", "attn"), + ("context_block", "mlp"), + ("context_block", "adaLN_modulation"), + ("x_block", "attn"), + ("x_block", "mlp"), + ("x_block", "adaLN_modulation"), + ] + for i, d in enumerate(type_dims): + if d is not None and all([id in lora_name for id in identifier[i]]): + dim = d # may be 0 for skip + break + + if is_mmdit and dim and self.train_block_indices is not None and "joint_blocks" in lora_name: + # "lora_unet_joint_blocks_0_x_block_attn_proj..." + block_index = int(lora_name.split("_")[4]) # bit dirty + if self.train_block_indices is not None and not self.train_block_indices[block_index]: + dim = 0 + + elif self.conv_lora_dim is not None: + dim = self.conv_lora_dim + alpha = self.conv_alpha + elif force_incl_conv2d: + # x_embedder + dim = default_dim if default_dim is not None else self.lora_dim + alpha = self.alpha + + if dim is None or dim == 0: + # skipした情報を出力 + if is_linear or is_conv2d_1x1 or (self.conv_lora_dim is not None): + skipped.append(lora_name) + continue + + # qkv split + split_dims = None + if is_mmdit and split_qkv: + if "joint_blocks" in lora_name and "qkv" in lora_name: + split_dims = [qkv_dim // 3] * 3 + + lora = module_class( + lora_name, + child_module, + self.multiplier, + dim, + alpha, + dropout=dropout, + rank_dropout=rank_dropout, + module_dropout=module_dropout, + split_dims=split_dims, + ) + loras.append(lora) + + if target_replace_modules is None: + break # all modules are searched + return loras, skipped + + # create LoRA for text encoder + # 毎回すべてのモジュールを作るのは無駄なので要検討 + self.text_encoder_loras: List[Union[LoRAModule, LoRAInfModule]] = [] + skipped_te = [] + for i, text_encoder in enumerate(text_encoders): + index = i + if not train_t5xxl and index >= 2: # 0: CLIP-L, 1: CLIP-G, 2: T5XXL, so we skip T5XXL if train_t5xxl is False + break + + logger.info(f"create LoRA for Text Encoder {index+1}:") + + text_encoder_loras, skipped = create_modules(False, index, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE) + logger.info(f"create LoRA for Text Encoder {index+1}: {len(text_encoder_loras)} modules.") + self.text_encoder_loras.extend(text_encoder_loras) + skipped_te += skipped + + # create LoRA for U-Net + self.unet_loras: List[Union[LoRAModule, LoRAInfModule]] + self.unet_loras, skipped_un = create_modules(True, None, unet, LoRANetwork.SD3_TARGET_REPLACE_MODULE) + + # emb_dims [context_embedder, t_embedder, x_embedder, y_embedder, final_mod, final_linear] + if self.emb_dims: + for filter, in_dim in zip( + [ + "context_embedder", + "_t_embedder", # don't use "t_embedder" because it's used in "context_embedder" + "x_embedder", + "y_embedder", + "final_layer_adaLN_modulation", + "final_layer_linear", + ], + self.emb_dims, + ): + # x_embedder is conv2d, so we need to include it + loras, _ = create_modules( + True, None, unet, None, filter=filter, default_dim=in_dim, include_conv2d_if_filter=filter == "x_embedder" + ) + # if len(loras) > 0: + # logger.info(f"create LoRA for {filter}: {len(loras)} modules.") + self.unet_loras.extend(loras) + + logger.info(f"create LoRA for SD3 MMDiT: {len(self.unet_loras)} modules.") + if verbose: + for lora in self.unet_loras: + logger.info(f"\t{lora.lora_name:50} {lora.lora_dim}, {lora.alpha}") + + skipped = skipped_te + skipped_un + if verbose and len(skipped) > 0: + logger.warning( + f"because dim (rank) is 0, {len(skipped)} LoRA modules are skipped / dim (rank)が0の為、次の{len(skipped)}個のLoRAモジュールはスキップされます:" + ) + for name in skipped: + logger.info(f"\t{name}") + + # assertion + names = set() + for lora in self.text_encoder_loras + self.unet_loras: + assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}" + names.add(lora.lora_name) + + def set_multiplier(self, multiplier): + self.multiplier = multiplier + for lora in self.text_encoder_loras + self.unet_loras: + lora.multiplier = self.multiplier + + def set_enabled(self, is_enabled): + for lora in self.text_encoder_loras + self.unet_loras: + lora.enabled = is_enabled + + def load_weights(self, file): + if os.path.splitext(file)[1] == ".safetensors": + from safetensors.torch import load_file + + weights_sd = load_file(file) + else: + weights_sd = torch.load(file, map_location="cpu") + + info = self.load_state_dict(weights_sd, False) + return info + + def load_state_dict(self, state_dict, strict=True): + # override to convert original weight to split qkv + if not self.split_qkv: + return super().load_state_dict(state_dict, strict) + + # split qkv + for key in list(state_dict.keys()): + if not ("joint_blocks" in key and "qkv" in key): + continue + + weight = state_dict[key] + lora_name = key.split(".")[0] + if "lora_down" in key and "weight" in key: + # dense weight (rank*3, in_dim) + split_weight = torch.chunk(weight, 3, dim=0) + for i, split_w in enumerate(split_weight): + state_dict[f"{lora_name}.lora_down.{i}.weight"] = split_w + + del state_dict[key] + # print(f"split {key}: {weight.shape} to {[w.shape for w in split_weight]}") + elif "lora_up" in key and "weight" in key: + # sparse weight (out_dim=sum(split_dims), rank*3) + rank = weight.size(1) // 3 + i = 0 + split_dim = weight.shape[0] // 3 + for j in range(3): + state_dict[f"{lora_name}.lora_up.{j}.weight"] = weight[i : i + split_dim, j * rank : (j + 1) * rank] + i += split_dim + del state_dict[key] + + # alpha is unchanged + + return super().load_state_dict(state_dict, strict) + + def state_dict(self, destination=None, prefix="", keep_vars=False): + if not self.split_qkv: + return super().state_dict(destination, prefix, keep_vars) + + # merge qkv + state_dict = super().state_dict(destination, prefix, keep_vars) + new_state_dict = {} + for key in list(state_dict.keys()): + if not ("joint_blocks" in key and "qkv" in key): + new_state_dict[key] = state_dict[key] + continue + + if key not in state_dict: + continue # already merged + + lora_name = key.split(".")[0] + + # (rank, in_dim) * 3 + down_weights = [state_dict.pop(f"{lora_name}.lora_down.{i}.weight") for i in range(3)] + # (split dim, rank) * 3 + up_weights = [state_dict.pop(f"{lora_name}.lora_up.{i}.weight") for i in range(3)] + + alpha = state_dict.pop(f"{lora_name}.alpha") + + # merge down weight + down_weight = torch.cat(down_weights, dim=0) # (rank, split_dim) * 3 -> (rank*3, sum of split_dim) + + # merge up weight (sum of split_dim, rank*3) + split_dim, rank = up_weights[0].size() + qkv_dim = split_dim * 3 + up_weight = torch.zeros((qkv_dim, down_weight.size(0)), device=down_weight.device, dtype=down_weight.dtype) + i = 0 + for j in range(3): + up_weight[i : i + split_dim, j * rank : (j + 1) * rank] = up_weights[j] + i += split_dim + + new_state_dict[f"{lora_name}.lora_down.weight"] = down_weight + new_state_dict[f"{lora_name}.lora_up.weight"] = up_weight + new_state_dict[f"{lora_name}.alpha"] = alpha + + # print( + # f"merged {lora_name}: {lora_name}, {[w.shape for w in down_weights]}, {[w.shape for w in up_weights]} to {down_weight.shape}, {up_weight.shape}" + # ) + print(f"new key: {lora_name}.lora_down.weight, {lora_name}.lora_up.weight, {lora_name}.alpha") + + return new_state_dict + + def apply_to(self, text_encoders, mmdit, apply_text_encoder=True, apply_unet=True): + if apply_text_encoder: + logger.info(f"enable LoRA for text encoder: {len(self.text_encoder_loras)} modules") + else: + self.text_encoder_loras = [] + + if apply_unet: + logger.info(f"enable LoRA for U-Net: {len(self.unet_loras)} modules") + else: + self.unet_loras = [] + + for lora in self.text_encoder_loras + self.unet_loras: + lora.apply_to() + self.add_module(lora.lora_name, lora) + + # マージできるかどうかを返す + def is_mergeable(self): + return True + + # TODO refactor to common function with apply_to + def merge_to(self, text_encoders, mmdit, weights_sd, dtype=None, device=None): + apply_text_encoder = apply_unet = False + for key in weights_sd.keys(): + if ( + key.startswith(LoRANetwork.LORA_PREFIX_TEXT_ENCODER_CLIP_L) + or key.startswith(LoRANetwork.LORA_PREFIX_TEXT_ENCODER_CLIP_G) + or key.startswith(LoRANetwork.LORA_PREFIX_TEXT_ENCODER_T5) + ): + apply_text_encoder = True + elif key.startswith(LoRANetwork.LORA_PREFIX_SD3): + apply_unet = True + + if apply_text_encoder: + logger.info("enable LoRA for text encoder") + else: + self.text_encoder_loras = [] + + if apply_unet: + logger.info("enable LoRA for U-Net") + else: + self.unet_loras = [] + + for lora in self.text_encoder_loras + self.unet_loras: + sd_for_lora = {} + for key in weights_sd.keys(): + if key.startswith(lora.lora_name): + sd_for_lora[key[len(lora.lora_name) + 1 :]] = weights_sd[key] + lora.merge_to(sd_for_lora, dtype, device) + + logger.info(f"weights are merged") + + def set_loraplus_lr_ratio(self, loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio): + self.loraplus_lr_ratio = loraplus_lr_ratio + self.loraplus_unet_lr_ratio = loraplus_unet_lr_ratio + self.loraplus_text_encoder_lr_ratio = loraplus_text_encoder_lr_ratio + + logger.info(f"LoRA+ UNet LR Ratio: {self.loraplus_unet_lr_ratio or self.loraplus_lr_ratio}") + logger.info(f"LoRA+ Text Encoder LR Ratio: {self.loraplus_text_encoder_lr_ratio or self.loraplus_lr_ratio}") + + def prepare_optimizer_params_with_multiple_te_lrs(self, text_encoder_lr, unet_lr, default_lr): + # make sure text_encoder_lr as list of three elements + # if float, use the same value for all three + if text_encoder_lr is None or (isinstance(text_encoder_lr, list) and len(text_encoder_lr) == 0): + text_encoder_lr = [default_lr, default_lr, default_lr] + elif isinstance(text_encoder_lr, float) or isinstance(text_encoder_lr, int): + text_encoder_lr = [float(text_encoder_lr), float(text_encoder_lr), float(text_encoder_lr)] + elif len(text_encoder_lr) == 1: + text_encoder_lr = [text_encoder_lr[0], text_encoder_lr[0], text_encoder_lr[0]] + elif len(text_encoder_lr) == 2: + text_encoder_lr = [text_encoder_lr[0], text_encoder_lr[1], text_encoder_lr[1]] + + self.requires_grad_(True) + + all_params = [] + lr_descriptions = [] + + def assemble_params(loras, lr, loraplus_ratio): + param_groups = {"lora": {}, "plus": {}} + for lora in loras: + for name, param in lora.named_parameters(): + if loraplus_ratio is not None and "lora_up" in name: + param_groups["plus"][f"{lora.lora_name}.{name}"] = param + else: + param_groups["lora"][f"{lora.lora_name}.{name}"] = param + + params = [] + descriptions = [] + for key in param_groups.keys(): + param_data = {"params": param_groups[key].values()} + + if len(param_data["params"]) == 0: + continue + + if lr is not None: + if key == "plus": + param_data["lr"] = lr * loraplus_ratio + else: + param_data["lr"] = lr + + if param_data.get("lr", None) == 0 or param_data.get("lr", None) is None: + logger.info("NO LR skipping!") + continue + + params.append(param_data) + descriptions.append("plus" if key == "plus" else "") + + return params, descriptions + + if self.text_encoder_loras: + loraplus_lr_ratio = self.loraplus_text_encoder_lr_ratio or self.loraplus_lr_ratio + + # split text encoder loras for te1 and te3 + te1_loras = [ + lora for lora in self.text_encoder_loras if lora.lora_name.startswith(self.LORA_PREFIX_TEXT_ENCODER_CLIP_L) + ] + te2_loras = [ + lora for lora in self.text_encoder_loras if lora.lora_name.startswith(self.LORA_PREFIX_TEXT_ENCODER_CLIP_G) + ] + te3_loras = [lora for lora in self.text_encoder_loras if lora.lora_name.startswith(self.LORA_PREFIX_TEXT_ENCODER_T5)] + if len(te1_loras) > 0: + logger.info(f"Text Encoder 1 (CLIP-L): {len(te1_loras)} modules, LR {text_encoder_lr[0]}") + params, descriptions = assemble_params(te1_loras, text_encoder_lr[0], loraplus_lr_ratio) + all_params.extend(params) + lr_descriptions.extend(["textencoder 1 " + (" " + d if d else "") for d in descriptions]) + if len(te2_loras) > 0: + logger.info(f"Text Encoder 2 (CLIP-G): {len(te2_loras)} modules, LR {text_encoder_lr[1]}") + params, descriptions = assemble_params(te2_loras, text_encoder_lr[1], loraplus_lr_ratio) + all_params.extend(params) + lr_descriptions.extend(["textencoder 1 " + (" " + d if d else "") for d in descriptions]) + if len(te3_loras) > 0: + logger.info(f"Text Encoder 3 (T5XXL): {len(te3_loras)} modules, LR {text_encoder_lr[2]}") + params, descriptions = assemble_params(te3_loras, text_encoder_lr[2], loraplus_lr_ratio) + all_params.extend(params) + lr_descriptions.extend(["textencoder 3 " + (" " + d if d else "") for d in descriptions]) + + if self.unet_loras: + params, descriptions = assemble_params( + self.unet_loras, + unet_lr if unet_lr is not None else default_lr, + self.loraplus_unet_lr_ratio or self.loraplus_lr_ratio, + ) + all_params.extend(params) + lr_descriptions.extend(["unet" + (" " + d if d else "") for d in descriptions]) + + return all_params, lr_descriptions + + def enable_gradient_checkpointing(self): + # not supported + pass + + def prepare_grad_etc(self, text_encoder, unet): + self.requires_grad_(True) + + def on_epoch_start(self, text_encoder, unet): + self.train() + + def get_trainable_params(self): + return self.parameters() + + def save_weights(self, file, dtype, metadata): + if metadata is not None and len(metadata) == 0: + metadata = None + + state_dict = self.state_dict() + + if dtype is not None: + for key in list(state_dict.keys()): + v = state_dict[key] + v = v.detach().clone().to("cpu").to(dtype) + state_dict[key] = v + + if os.path.splitext(file)[1] == ".safetensors": + from safetensors.torch import save_file + from library import train_util + + # Precalculate model hashes to save time on indexing + if metadata is None: + metadata = {} + model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata) + metadata["sshs_model_hash"] = model_hash + metadata["sshs_legacy_hash"] = legacy_hash + + save_file(state_dict, file, metadata) + else: + torch.save(state_dict, file) + + def backup_weights(self): + # 重みのバックアップを行う + loras: List[LoRAInfModule] = self.text_encoder_loras + self.unet_loras + for lora in loras: + org_module = lora.org_module_ref[0] + if not hasattr(org_module, "_lora_org_weight"): + sd = org_module.state_dict() + org_module._lora_org_weight = sd["weight"].detach().clone() + org_module._lora_restored = True + + def restore_weights(self): + # 重みのリストアを行う + loras: List[LoRAInfModule] = self.text_encoder_loras + self.unet_loras + for lora in loras: + org_module = lora.org_module_ref[0] + if not org_module._lora_restored: + sd = org_module.state_dict() + sd["weight"] = org_module._lora_org_weight + org_module.load_state_dict(sd) + org_module._lora_restored = True + + def pre_calculation(self): + # 事前計算を行う + loras: List[LoRAInfModule] = self.text_encoder_loras + self.unet_loras + for lora in loras: + org_module = lora.org_module_ref[0] + sd = org_module.state_dict() + + org_weight = sd["weight"] + lora_weight = lora.get_weight().to(org_weight.device, dtype=org_weight.dtype) + sd["weight"] = org_weight + lora_weight + assert sd["weight"].shape == org_weight.shape + org_module.load_state_dict(sd) + + org_module._lora_restored = False + lora.enabled = False + + def apply_max_norm_regularization(self, max_norm_value, device): + downkeys = [] + upkeys = [] + alphakeys = [] + norms = [] + keys_scaled = 0 + + state_dict = self.state_dict() + for key in state_dict.keys(): + if "lora_down" in key and "weight" in key: + downkeys.append(key) + upkeys.append(key.replace("lora_down", "lora_up")) + alphakeys.append(key.replace("lora_down.weight", "alpha")) + + for i in range(len(downkeys)): + down = state_dict[downkeys[i]].to(device) + up = state_dict[upkeys[i]].to(device) + alpha = state_dict[alphakeys[i]].to(device) + dim = down.shape[0] + scale = alpha / dim + + if up.shape[2:] == (1, 1) and down.shape[2:] == (1, 1): + updown = (up.squeeze(2).squeeze(2) @ down.squeeze(2).squeeze(2)).unsqueeze(2).unsqueeze(3) + elif up.shape[2:] == (3, 3) or down.shape[2:] == (3, 3): + updown = torch.nn.functional.conv2d(down.permute(1, 0, 2, 3), up).permute(1, 0, 2, 3) + else: + updown = up @ down + + updown *= scale + + norm = updown.norm().clamp(min=max_norm_value / 2) + desired = torch.clamp(norm, max=max_norm_value) + ratio = desired.cpu() / norm.cpu() + sqrt_ratio = ratio**0.5 + if ratio != 1: + keys_scaled += 1 + state_dict[upkeys[i]] *= sqrt_ratio + state_dict[downkeys[i]] *= sqrt_ratio + scalednorm = updown.norm() * ratio + norms.append(scalednorm.item()) + + return keys_scaled, sum(norms) / len(norms), max(norms) diff --git a/networks/oft.py b/networks/oft.py index 6321def3b..0c3a5393f 100644 --- a/networks/oft.py +++ b/networks/oft.py @@ -51,7 +51,7 @@ def __init__( alpha = alpha.detach().numpy() # constraint in original paper is alpha * out_dim * out_dim, but we use alpha * out_dim for backward compatibility - # original alpha is 1e-6, so we use 1e-3 or 1e-4 for alpha + # original alpha is 1e-5, so we use 1e-2 or 1e-4 for alpha self.constraint = alpha * out_dim self.register_buffer("alpha", torch.tensor(alpha)) diff --git a/networks/oft_flux.py b/networks/oft_flux.py new file mode 100644 index 000000000..27b8b637a --- /dev/null +++ b/networks/oft_flux.py @@ -0,0 +1,482 @@ +# OFT network module + +import math +import os +from typing import Dict, List, Optional, Tuple, Type, Union +from diffusers import AutoencoderKL +import einops +from transformers import CLIPTextModel +import numpy as np +import torch +import torch.nn.functional as F +import re +from library.utils import setup_logging + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + + +class OFTModule(torch.nn.Module): + """ + replaces forward method of the original Linear, instead of replacing the original Linear module. + """ + + def __init__( + self, + oft_name, + org_module: torch.nn.Module, + multiplier=1.0, + dim=4, + alpha=1, + split_dims: Optional[List[int]] = None, + ): + """ + dim -> num blocks + alpha -> constraint + + split_dims is used to mimic the split qkv of FLUX as same as Diffusers + """ + super().__init__() + self.oft_name = oft_name + self.num_blocks = dim + + if type(alpha) == torch.Tensor: + alpha = alpha.detach().numpy() + self.register_buffer("alpha", torch.tensor(alpha)) + + # No conv2d in FLUX + # if "Linear" in org_module.__class__.__name__: + self.out_dim = org_module.out_features + # elif "Conv" in org_module.__class__.__name__: + # out_dim = org_module.out_channels + + if split_dims is None: + split_dims = [self.out_dim] + else: + assert sum(split_dims) == self.out_dim, "sum of split_dims must be equal to out_dim" + self.split_dims = split_dims + + # assert all dim is divisible by num_blocks + for split_dim in self.split_dims: + assert split_dim % self.num_blocks == 0, "split_dim must be divisible by num_blocks" + + self.constraint = [alpha * split_dim for split_dim in self.split_dims] + self.block_size = [split_dim // self.num_blocks for split_dim in self.split_dims] + self.oft_blocks = torch.nn.ParameterList( + [torch.nn.Parameter(torch.zeros(self.num_blocks, block_size, block_size)) for block_size in self.block_size] + ) + self.I = [torch.eye(block_size).unsqueeze(0).repeat(self.num_blocks, 1, 1) for block_size in self.block_size] + + self.shape = org_module.weight.shape + self.multiplier = multiplier + self.org_module = [org_module] # moduleにならないようにlistに入れる + + def apply_to(self): + self.org_forward = self.org_module[0].forward + self.org_module[0].forward = self.forward + + def get_weight(self, multiplier=None): + if multiplier is None: + multiplier = self.multiplier + + if self.I[0].device != self.oft_blocks[0].device: + self.I = [I.to(self.oft_blocks[0].device) for I in self.I] + + block_R_weighted_list = [] + for i in range(len(self.oft_blocks)): + block_Q = self.oft_blocks[i] - self.oft_blocks[i].transpose(1, 2) + norm_Q = torch.norm(block_Q.flatten()) + new_norm_Q = torch.clamp(norm_Q, max=self.constraint[i]) + block_Q = block_Q * ((new_norm_Q + 1e-8) / (norm_Q + 1e-8)) + + I = self.I[i] + block_R = torch.matmul(I + block_Q, (I - block_Q).float().inverse()) + block_R_weighted = self.multiplier * (block_R - I) + I + + block_R_weighted_list.append(block_R_weighted) + + return block_R_weighted_list + + def forward(self, x, scale=None): + if self.multiplier == 0.0: + return self.org_forward(x) + + org_module = self.org_module[0] + org_dtype = x.dtype + + R = self.get_weight() + W = org_module.weight.to(torch.float32) + B = org_module.bias.to(torch.float32) + + # split W to match R + results = [] + d2 = 0 + for i in range(len(R)): + d1 = d2 + d2 += self.split_dims[i] + + W1 = W[d1:d2] + W_reshaped = einops.rearrange(W1, "(k n) m -> k n m", k=self.num_blocks, n=self.block_size[i]) + RW_1 = torch.einsum("k n m, k n p -> k m p", R[i], W_reshaped) + RW_1 = einops.rearrange(RW_1, "k m p -> (k m) p") + + B1 = B[d1:d2] + result = F.linear(x, RW_1.to(org_dtype), B1.to(org_dtype)) + results.append(result) + + result = torch.cat(results, dim=-1) + return result + + +class OFTInfModule(OFTModule): + def __init__( + self, + oft_name, + org_module: torch.nn.Module, + multiplier=1.0, + dim=4, + alpha=1, + split_dims: Optional[List[int]] = None, + **kwargs, + ): + # no dropout for inference + super().__init__(oft_name, org_module, multiplier, dim, alpha, split_dims) + self.enabled = True + self.network: OFTNetwork = None + + def set_network(self, network): + self.network = network + + def forward(self, x, scale=None): + if not self.enabled: + return self.org_forward(x) + return super().forward(x, scale) + + def merge_to(self, multiplier=None): + # get org weight + org_sd = self.org_module[0].state_dict() + W = org_sd["weight"].to(torch.float32) + R = self.get_weight(multiplier).to(torch.float32) + + d2 = 0 + W_list = [] + for i in range(len(self.oft_blocks)): + d1 = d2 + d2 += self.split_dims[i] + + W1 = W[d1:d2] + W_reshaped = einops.rearrange(W1, "(k n) m -> k n m", k=self.num_blocks, n=self.block_size[i]) + W1 = torch.einsum("k n m, k n p -> k m p", R[i], W_reshaped) + W1 = einops.rearrange(W1, "k m p -> (k m) p") + + W_list.append(W1) + + W = torch.cat(W_list, dim=-1) + + # convert back to original dtype + W = W.to(org_sd["weight"].dtype) + + # set weight to org_module + org_sd["weight"] = W + self.org_module[0].load_state_dict(org_sd) + + +def create_network( + multiplier: float, + network_dim: Optional[int], + network_alpha: Optional[float], + vae: AutoencoderKL, + text_encoder: Union[CLIPTextModel, List[CLIPTextModel]], + unet, + neuron_dropout: Optional[float] = None, + **kwargs, +): + if network_dim is None: + network_dim = 4 # default + if network_alpha is None: # should be set + logger.info( + "network_alpha is not set, use default value 1e-3 / network_alphaが設定されていないのでデフォルト値 1e-3 を使用します" + ) + network_alpha = 1e-3 + elif network_alpha >= 1: + logger.warning( + "network_alpha is too large (>=1, maybe default value is too large), please consider to set smaller value like 1e-3" + " / network_alphaが大きすぎるようです(>=1, デフォルト値が大きすぎる可能性があります)。1e-3のような小さな値を推奨" + ) + + # attn only or all linear (FFN) layers + enable_all_linear = kwargs.get("enable_all_linear", None) + # enable_conv = kwargs.get("enable_conv", None) + if enable_all_linear is not None: + enable_all_linear = bool(enable_all_linear) + # if enable_conv is not None: + # enable_conv = bool(enable_conv) + + network = OFTNetwork( + text_encoder, + unet, + multiplier=multiplier, + dim=network_dim, + alpha=network_alpha, + enable_all_linear=enable_all_linear, + varbose=True, + ) + return network + + +# Create network from weights for inference, weights are not loaded here (because can be merged) +def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weights_sd=None, for_inference=False, **kwargs): + if weights_sd is None: + if os.path.splitext(file)[1] == ".safetensors": + from safetensors.torch import load_file, safe_open + + weights_sd = load_file(file) + else: + weights_sd = torch.load(file, map_location="cpu") + + # check dim, alpha and if weights have for conv2d + dim = None + alpha = None + all_linear = None + for name, param in weights_sd.items(): + if name.endswith(".alpha"): + if alpha is None: + alpha = param.item() + elif "qkv" in name: + continue # ignore qkv + else: + if dim is None: + dim = param.size()[0] + if all_linear is None and "_mlp" in name: + all_linear = True + if dim is not None and alpha is not None and all_linear is not None: + break + if all_linear is None: + all_linear = False + + module_class = OFTInfModule if for_inference else OFTModule + network = OFTNetwork( + text_encoder, + unet, + multiplier=multiplier, + dim=dim, + alpha=alpha, + enable_all_linear=all_linear, + module_class=module_class, + ) + return network, weights_sd + + +class OFTNetwork(torch.nn.Module): + FLUX_TARGET_REPLACE_MODULE_ALL_LINEAR = ["DoubleStreamBlock", "SingleStreamBlock"] + FLUX_TARGET_REPLACE_MODULE_ATTN_ONLY = ["SelfAttention"] + OFT_PREFIX_UNET = "oft_unet" + + def __init__( + self, + text_encoder: Union[List[CLIPTextModel], CLIPTextModel], + unet, + multiplier: float = 1.0, + dim: int = 4, + alpha: float = 1, + enable_all_linear: Optional[bool] = False, + module_class: Union[Type[OFTModule], Type[OFTInfModule]] = OFTModule, + varbose: Optional[bool] = False, + ) -> None: + super().__init__() + self.train_t5xxl = False # make compatible with LoRA + self.multiplier = multiplier + + self.dim = dim + self.alpha = alpha + + logger.info( + f"create OFT network. num blocks: {self.dim}, constraint: {self.alpha}, multiplier: {self.multiplier}, enable_all_linear: {enable_all_linear}" + ) + + # create module instances + def create_modules( + root_module: torch.nn.Module, + target_replace_modules: List[torch.nn.Module], + ) -> List[OFTModule]: + prefix = self.OFT_PREFIX_UNET + ofts = [] + for name, module in root_module.named_modules(): + if module.__class__.__name__ in target_replace_modules: + for child_name, child_module in module.named_modules(): + is_linear = "Linear" in child_module.__class__.__name__ + + if is_linear: + oft_name = prefix + "." + name + "." + child_name + oft_name = oft_name.replace(".", "_") + # logger.info(oft_name) + + if "double" in oft_name and "qkv" in oft_name: + split_dims = [3072] * 3 + elif "single" in oft_name and "linear1" in oft_name: + split_dims = [3072] * 3 + [12288] + else: + split_dims = None + + oft = module_class(oft_name, child_module, self.multiplier, dim, alpha, split_dims) + ofts.append(oft) + return ofts + + # extend U-Net target modules if conv2d 3x3 is enabled, or load from weights + if enable_all_linear: + target_modules = OFTNetwork.FLUX_TARGET_REPLACE_MODULE_ALL_LINEAR + else: + target_modules = OFTNetwork.FLUX_TARGET_REPLACE_MODULE_ATTN_ONLY + + self.unet_ofts: List[OFTModule] = create_modules(unet, target_modules) + logger.info(f"create OFT for Flux: {len(self.unet_ofts)} modules.") + + # assertion + names = set() + for oft in self.unet_ofts: + assert oft.oft_name not in names, f"duplicated oft name: {oft.oft_name}" + names.add(oft.oft_name) + + def set_multiplier(self, multiplier): + self.multiplier = multiplier + for oft in self.unet_ofts: + oft.multiplier = self.multiplier + + def load_weights(self, file): + if os.path.splitext(file)[1] == ".safetensors": + from safetensors.torch import load_file + + weights_sd = load_file(file) + else: + weights_sd = torch.load(file, map_location="cpu") + + info = self.load_state_dict(weights_sd, False) + return info + + def apply_to(self, text_encoder, unet, apply_text_encoder=True, apply_unet=True): + assert apply_unet, "apply_unet must be True" + + for oft in self.unet_ofts: + oft.apply_to() + self.add_module(oft.oft_name, oft) + + # マージできるかどうかを返す + def is_mergeable(self): + return True + + # TODO refactor to common function with apply_to + def merge_to(self, text_encoder, unet, weights_sd, dtype, device): + logger.info("enable OFT for U-Net") + + for oft in self.unet_ofts: + sd_for_lora = {} + for key in weights_sd.keys(): + if key.startswith(oft.oft_name): + sd_for_lora[key[len(oft.oft_name) + 1 :]] = weights_sd[key] + oft.load_state_dict(sd_for_lora, False) + oft.merge_to() + + logger.info(f"weights are merged") + + # 二つのText Encoderに別々の学習率を設定できるようにするといいかも + def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr): + self.requires_grad_(True) + all_params = [] + + def enumerate_params(ofts): + params = [] + for oft in ofts: + params.extend(oft.parameters()) + + # logger.info num of params + num_params = 0 + for p in params: + num_params += p.numel() + logger.info(f"OFT params: {num_params}") + return params + + param_data = {"params": enumerate_params(self.unet_ofts)} + if unet_lr is not None: + param_data["lr"] = unet_lr + all_params.append(param_data) + + return all_params + + def enable_gradient_checkpointing(self): + # not supported + pass + + def prepare_grad_etc(self, text_encoder, unet): + self.requires_grad_(True) + + def on_epoch_start(self, text_encoder, unet): + self.train() + + def get_trainable_params(self): + return self.parameters() + + def save_weights(self, file, dtype, metadata): + if metadata is not None and len(metadata) == 0: + metadata = None + + state_dict = self.state_dict() + + if dtype is not None: + for key in list(state_dict.keys()): + v = state_dict[key] + v = v.detach().clone().to("cpu").to(dtype) + state_dict[key] = v + + if os.path.splitext(file)[1] == ".safetensors": + from safetensors.torch import save_file + from library import train_util + + # Precalculate model hashes to save time on indexing + if metadata is None: + metadata = {} + model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata) + metadata["sshs_model_hash"] = model_hash + metadata["sshs_legacy_hash"] = legacy_hash + + save_file(state_dict, file, metadata) + else: + torch.save(state_dict, file) + + def backup_weights(self): + # 重みのバックアップを行う + ofts: List[OFTInfModule] = self.unet_ofts + for oft in ofts: + org_module = oft.org_module[0] + if not hasattr(org_module, "_lora_org_weight"): + sd = org_module.state_dict() + org_module._lora_org_weight = sd["weight"].detach().clone() + org_module._lora_restored = True + + def restore_weights(self): + # 重みのリストアを行う + ofts: List[OFTInfModule] = self.unet_ofts + for oft in ofts: + org_module = oft.org_module[0] + if not org_module._lora_restored: + sd = org_module.state_dict() + sd["weight"] = org_module._lora_org_weight + org_module.load_state_dict(sd) + org_module._lora_restored = True + + def pre_calculation(self): + # 事前計算を行う + ofts: List[OFTInfModule] = self.unet_ofts + for oft in ofts: + org_module = oft.org_module[0] + oft.merge_to() + # sd = org_module.state_dict() + # org_weight = sd["weight"] + # lora_weight = oft.get_weight().to(org_weight.device, dtype=org_weight.dtype) + # sd["weight"] = org_weight + lora_weight + # assert sd["weight"].shape == org_weight.shape + # org_module.load_state_dict(sd) + + org_module._lora_restored = False + oft.enabled = False diff --git a/pytest.ini b/pytest.ini new file mode 100644 index 000000000..34b7e9c1f --- /dev/null +++ b/pytest.ini @@ -0,0 +1,9 @@ +[pytest] +minversion = 6.0 +testpaths = + tests +filterwarnings = + ignore::DeprecationWarning + ignore::UserWarning + ignore::FutureWarning +pythonpath = . diff --git a/pytorch_lightning/__init__.py b/pytorch_lightning/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/pytorch_lightning/callbacks/__init__.py b/pytorch_lightning/callbacks/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py new file mode 100644 index 000000000..1ba145634 --- /dev/null +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -0,0 +1,4 @@ +# dummy module for pytorch_lightning + +class ModelCheckpoint: + pass diff --git a/requirements.txt b/requirements.txt index e6e1bf6fc..624978b49 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,24 +1,29 @@ -accelerate==0.30.0 -transformers==4.44.0 -diffusers[torch]==0.25.0 -ftfy==6.1.1 +accelerate==1.6.0 +transformers==4.54.1 +diffusers[torch]==0.32.1 +ftfy==6.3.1 # albumentations==1.3.0 -opencv-python==4.8.1.78 +opencv-python==4.10.0.84 einops==0.7.0 -pytorch-lightning==1.9.0 -bitsandbytes==0.44.0 -prodigyopt==1.0 -lion-pytorch==0.0.6 +# pytorch-lightning==1.9.0 +bitsandbytes +lion-pytorch==0.2.3 +schedulefree==1.4 +pytorch-optimizer==3.7.0 +prodigy-plus-schedule-free==1.9.2 +prodigyopt==1.1.2 tensorboard -safetensors==0.4.2 +safetensors==0.4.5 # gradio==3.16.2 -altair==4.2.2 -easygui==0.98.3 +# altair==4.2.2 +# easygui==0.98.3 toml==0.10.2 -voluptuous==0.13.1 -huggingface-hub==0.24.5 +voluptuous==0.15.2 +huggingface-hub==0.34.3 # for Image utils imagesize==1.4.1 +numpy +# <=2.0 # for BLIP captioning # requests==2.28.2 # timm==0.6.12 @@ -37,6 +42,8 @@ imagesize==1.4.1 # open clip for SDXL # open-clip-torch==2.20.0 # For logging -rich==13.7.0 +rich==14.1.0 +# for T5XXL tokenizer (SD3/FLUX) +sentencepiece==0.2.1 # for kohya_ss library -e . diff --git a/sd3_minimal_inference.py b/sd3_minimal_inference.py new file mode 100644 index 000000000..d7b97a59f --- /dev/null +++ b/sd3_minimal_inference.py @@ -0,0 +1,407 @@ +# Minimum Inference Code for SD3 + +import argparse +import datetime +import math +import os +import random +from typing import Optional, Tuple +import numpy as np + +import torch +from safetensors.torch import safe_open, load_file +import torch.amp +from tqdm import tqdm +from PIL import Image +from transformers import CLIPTextModelWithProjection, T5EncoderModel + +from library.device_utils import init_ipex, get_preferred_device +from networks import lora_sd3 + +init_ipex() + +from library.utils import setup_logging + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + +from library import sd3_models, sd3_utils, strategy_sd3 +from library.safetensors_utils import load_safetensors + + +def get_noise(seed, latent, device="cpu"): + # generator = torch.manual_seed(seed) + generator = torch.Generator(device) + generator.manual_seed(seed) + return torch.randn(latent.size(), dtype=latent.dtype, layout=latent.layout, generator=generator, device=device) + + +def get_sigmas(sampling: sd3_utils.ModelSamplingDiscreteFlow, steps): + start = sampling.timestep(sampling.sigma_max) + end = sampling.timestep(sampling.sigma_min) + timesteps = torch.linspace(start, end, steps) + sigs = [] + for x in range(len(timesteps)): + ts = timesteps[x] + sigs.append(sampling.sigma(ts)) + sigs += [0.0] + return torch.FloatTensor(sigs) + + +def max_denoise(model_sampling, sigmas): + max_sigma = float(model_sampling.sigma_max) + sigma = float(sigmas[0]) + return math.isclose(max_sigma, sigma, rel_tol=1e-05) or sigma > max_sigma + + +def do_sample( + height: int, + width: int, + initial_latent: Optional[torch.Tensor], + seed: int, + cond: Tuple[torch.Tensor, torch.Tensor], + neg_cond: Tuple[torch.Tensor, torch.Tensor], + mmdit: sd3_models.MMDiT, + steps: int, + cfg_scale: float, + dtype: torch.dtype, + device: str, +): + if initial_latent is None: + # latent = torch.ones(1, 16, height // 8, width // 8, device=device) * 0.0609 # this seems to be a bug in the original code. thanks to furusu for pointing it out + latent = torch.zeros(1, 16, height // 8, width // 8, device=device) + else: + latent = initial_latent + + latent = latent.to(dtype).to(device) + + noise = get_noise(seed, latent, device) + + model_sampling = sd3_utils.ModelSamplingDiscreteFlow(shift=3.0) # 3.0 is for SD3 + + sigmas = get_sigmas(model_sampling, steps).to(device) + # sigmas = sigmas[int(steps * (1 - denoise)) :] # do not support i2i + + # conditioning = fix_cond(conditioning) + # neg_cond = fix_cond(neg_cond) + # extra_args = {"cond": cond, "uncond": neg_cond, "cond_scale": guidance_scale} + + noise_scaled = model_sampling.noise_scaling(sigmas[0], noise, latent, max_denoise(model_sampling, sigmas)) + + c_crossattn = torch.cat([cond[0], neg_cond[0]]).to(device).to(dtype) + y = torch.cat([cond[1], neg_cond[1]]).to(device).to(dtype) + + x = noise_scaled.to(device).to(dtype) + # print(x.shape) + + with torch.no_grad(): + for i in tqdm(range(len(sigmas) - 1)): + sigma_hat = sigmas[i] + + timestep = model_sampling.timestep(sigma_hat).float() + timestep = torch.FloatTensor([timestep, timestep]).to(device) + + x_c_nc = torch.cat([x, x], dim=0) + # print(x_c_nc.shape, timestep.shape, c_crossattn.shape, y.shape) + + with torch.autocast(device_type=device.type, dtype=dtype): + model_output = mmdit(x_c_nc, timestep, context=c_crossattn, y=y) + model_output = model_output.float() + batched = model_sampling.calculate_denoised(sigma_hat, model_output, x) + + pos_out, neg_out = batched.chunk(2) + denoised = neg_out + (pos_out - neg_out) * cfg_scale + # print(denoised.shape) + + # d = to_d(x, sigma_hat, denoised) + dims_to_append = x.ndim - sigma_hat.ndim + sigma_hat_dims = sigma_hat[(...,) + (None,) * dims_to_append] + # print(dims_to_append, x.shape, sigma_hat.shape, denoised.shape, sigma_hat_dims.shape) + """Converts a denoiser output to a Karras ODE derivative.""" + d = (x - denoised) / sigma_hat_dims + + dt = sigmas[i + 1] - sigma_hat + + # Euler method + x = x + d * dt + x = x.to(dtype) + + latent = x + latent = vae.process_out(latent) + return latent + + +def generate_image( + mmdit: sd3_models.MMDiT, + vae: sd3_models.SDVAE, + clip_l: CLIPTextModelWithProjection, + clip_g: CLIPTextModelWithProjection, + t5xxl: T5EncoderModel, + steps: int, + prompt: str, + seed: int, + target_width: int, + target_height: int, + device: str, + negative_prompt: str, + cfg_scale: float, +): + # prepare embeddings + logger.info("Encoding prompts...") + + # TODO support one-by-one offloading + clip_l.to(device) + clip_g.to(device) + t5xxl.to(device) + + with torch.autocast(device_type=device.type, dtype=mmdit.dtype), torch.no_grad(): + tokens_and_masks = tokenize_strategy.tokenize(prompt) + lg_out, t5_out, pooled, l_attn_mask, g_attn_mask, t5_attn_mask = encoding_strategy.encode_tokens( + tokenize_strategy, [clip_l, clip_g, t5xxl], tokens_and_masks, args.apply_lg_attn_mask, args.apply_t5_attn_mask + ) + cond = encoding_strategy.concat_encodings(lg_out, t5_out, pooled) + + tokens_and_masks = tokenize_strategy.tokenize(negative_prompt) + lg_out, t5_out, pooled, neg_l_attn_mask, neg_g_attn_mask, neg_t5_attn_mask = encoding_strategy.encode_tokens( + tokenize_strategy, [clip_l, clip_g, t5xxl], tokens_and_masks, args.apply_lg_attn_mask, args.apply_t5_attn_mask + ) + neg_cond = encoding_strategy.concat_encodings(lg_out, t5_out, pooled) + + # attn masks are not used currently + + if args.offload: + clip_l.to("cpu") + clip_g.to("cpu") + t5xxl.to("cpu") + + # generate image + logger.info("Generating image...") + mmdit.to(device) + latent_sampled = do_sample(target_height, target_width, None, seed, cond, neg_cond, mmdit, steps, cfg_scale, sd3_dtype, device) + if args.offload: + mmdit.to("cpu") + + # latent to image + vae.to(device) + with torch.no_grad(): + image = vae.decode(latent_sampled) + + if args.offload: + vae.to("cpu") + + image = image.float() + image = torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)[0] + decoded_np = 255.0 * np.moveaxis(image.cpu().numpy(), 0, 2) + decoded_np = decoded_np.astype(np.uint8) + out_image = Image.fromarray(decoded_np) + + # save image + output_dir = args.output_dir + os.makedirs(output_dir, exist_ok=True) + output_path = os.path.join(output_dir, f"{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}.png") + out_image.save(output_path) + + logger.info(f"Saved image to {output_path}") + + +if __name__ == "__main__": + target_height = 1024 + target_width = 1024 + + # steps = 50 # 28 # 50 + # cfg_scale = 5 + # seed = 1 # None # 1 + + device = get_preferred_device() + + parser = argparse.ArgumentParser() + parser.add_argument("--ckpt_path", type=str, required=True) + parser.add_argument("--clip_g", type=str, required=False) + parser.add_argument("--clip_l", type=str, required=False) + parser.add_argument("--t5xxl", type=str, required=False) + parser.add_argument("--t5xxl_token_length", type=int, default=256, help="t5xxl token length, default: 256") + parser.add_argument("--apply_lg_attn_mask", action="store_true") + parser.add_argument("--apply_t5_attn_mask", action="store_true") + parser.add_argument("--prompt", type=str, default="A photo of a cat") + # parser.add_argument("--prompt2", type=str, default=None) # do not support different prompts for text encoders + parser.add_argument("--negative_prompt", type=str, default="") + parser.add_argument("--cfg_scale", type=float, default=5.0) + parser.add_argument("--offload", action="store_true", help="Offload to CPU") + parser.add_argument("--output_dir", type=str, default=".") + # parser.add_argument("--do_not_use_t5xxl", action="store_true") + # parser.add_argument("--attn_mode", type=str, default="torch", help="torch (SDPA) or xformers. default: torch") + parser.add_argument("--fp16", action="store_true") + parser.add_argument("--bf16", action="store_true") + parser.add_argument("--seed", type=int, default=1) + parser.add_argument("--steps", type=int, default=50) + parser.add_argument( + "--lora_weights", + type=str, + nargs="*", + default=[], + help="LoRA weights, only supports networks.lora_sd3, each argument is a `path;multiplier` (semi-colon separated)", + ) + parser.add_argument("--merge_lora_weights", action="store_true", help="Merge LoRA weights to model") + parser.add_argument("--width", type=int, default=target_width) + parser.add_argument("--height", type=int, default=target_height) + parser.add_argument("--interactive", action="store_true") + args = parser.parse_args() + + seed = args.seed + steps = args.steps + + sd3_dtype = torch.float32 + if args.fp16: + sd3_dtype = torch.float16 + elif args.bf16: + sd3_dtype = torch.bfloat16 + + loading_device = "cpu" if args.offload else device + + # load state dict + logger.info(f"Loading SD3 models from {args.ckpt_path}...") + # state_dict = load_file(args.ckpt_path) + state_dict = load_safetensors(args.ckpt_path, loading_device, disable_mmap=True, dtype=sd3_dtype) + + # load text encoders + clip_l = sd3_utils.load_clip_l(args.clip_l, sd3_dtype, loading_device, state_dict=state_dict) + clip_g = sd3_utils.load_clip_g(args.clip_g, sd3_dtype, loading_device, state_dict=state_dict) + t5xxl = sd3_utils.load_t5xxl(args.t5xxl, sd3_dtype, loading_device, state_dict=state_dict) + + # MMDiT and VAE + vae = sd3_utils.load_vae(None, sd3_dtype, loading_device, state_dict=state_dict) + mmdit = sd3_utils.load_mmdit(state_dict, sd3_dtype, loading_device) + + clip_l.to(sd3_dtype) + clip_g.to(sd3_dtype) + t5xxl.to(sd3_dtype) + vae.to(sd3_dtype) + mmdit.to(sd3_dtype) + if not args.offload: + # make sure to move to the device: some tensors are created in the constructor on the CPU + clip_l.to(device) + clip_g.to(device) + t5xxl.to(device) + vae.to(device) + mmdit.to(device) + + clip_l.eval() + clip_g.eval() + t5xxl.eval() + mmdit.eval() + vae.eval() + + # load tokenizers + logger.info("Loading tokenizers...") + tokenize_strategy = strategy_sd3.Sd3TokenizeStrategy(args.t5xxl_token_length) + encoding_strategy = strategy_sd3.Sd3TextEncodingStrategy() + + # LoRA + lora_models: list[lora_sd3.LoRANetwork] = [] + for weights_file in args.lora_weights: + if ";" in weights_file: + weights_file, multiplier = weights_file.split(";") + multiplier = float(multiplier) + else: + multiplier = 1.0 + + weights_sd = load_file(weights_file) + module = lora_sd3 + lora_model, _ = module.create_network_from_weights(multiplier, None, vae, [clip_l, clip_g, t5xxl], mmdit, weights_sd, True) + + if args.merge_lora_weights: + lora_model.merge_to([clip_l, clip_g, t5xxl], mmdit, weights_sd) + else: + lora_model.apply_to([clip_l, clip_g, t5xxl], mmdit) + info = lora_model.load_state_dict(weights_sd, strict=True) + logger.info(f"Loaded LoRA weights from {weights_file}: {info}") + lora_model.eval() + lora_model.to(device) + + lora_models.append(lora_model) + + if not args.interactive: + generate_image( + mmdit, + vae, + clip_l, + clip_g, + t5xxl, + args.steps, + args.prompt, + args.seed, + args.width, + args.height, + device, + args.negative_prompt, + args.cfg_scale, + ) + else: + # loop for interactive + width = args.width + height = args.height + steps = None + cfg_scale = args.cfg_scale + + while True: + print( + "Enter prompt (empty to exit). Options: --w --h --s --d " + " --n , `--n -` for empty negative prompt" + "Options are kept for the next prompt. Current options:" + f" width={width}, height={height}, steps={steps}, seed={seed}, cfg_scale={cfg_scale}" + ) + prompt = input() + if prompt == "": + break + + # parse options + options = prompt.split("--") + prompt = options[0].strip() + seed = None + negative_prompt = None + for opt in options[1:]: + try: + opt = opt.strip() + if opt.startswith("w"): + width = int(opt[1:].strip()) + elif opt.startswith("h"): + height = int(opt[1:].strip()) + elif opt.startswith("s"): + steps = int(opt[1:].strip()) + elif opt.startswith("d"): + seed = int(opt[1:].strip()) + elif opt.startswith("m"): + mutipliers = opt[1:].strip().split(",") + if len(mutipliers) != len(lora_models): + logger.error(f"Invalid number of multipliers, expected {len(lora_models)}") + continue + for i, lora_model in enumerate(lora_models): + lora_model.set_multiplier(float(mutipliers[i])) + elif opt.startswith("n"): + negative_prompt = opt[1:].strip() + if negative_prompt == "-": + negative_prompt = "" + elif opt.startswith("c"): + cfg_scale = float(opt[1:].strip()) + except ValueError as e: + logger.error(f"Invalid option: {opt}, {e}") + + generate_image( + mmdit, + vae, + clip_l, + clip_g, + t5xxl, + steps if steps is not None else args.steps, + prompt, + seed if seed is not None else args.seed, + width, + height, + device, + negative_prompt if negative_prompt is not None else args.negative_prompt, + cfg_scale, + ) + + logger.info("Done!") diff --git a/sd3_train.py b/sd3_train.py new file mode 100644 index 000000000..c6a2fdd8d --- /dev/null +++ b/sd3_train.py @@ -0,0 +1,1078 @@ +# training with captions + +import argparse +from concurrent.futures import ThreadPoolExecutor +import copy +import math +import os +from multiprocessing import Value +from typing import List +import toml + +from tqdm import tqdm + +import torch +from library import utils +from library.device_utils import init_ipex, clean_memory_on_device +from library.safetensors_utils import load_safetensors + +init_ipex() + +from accelerate.utils import set_seed +from diffusers import DDPMScheduler +from library import deepspeed_utils, sd3_models, sd3_train_utils, sd3_utils, strategy_base, strategy_sd3 + +import library.sai_model_spec as sai_model_spec +from library.sdxl_train_util import match_mixed_precision + +# , sdxl_model_util + +import library.train_util as train_util + +from library.utils import setup_logging, add_logging_arguments + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + +import library.config_util as config_util + +# import library.sdxl_train_util as sdxl_train_util +from library.config_util import ( + ConfigSanitizer, + BlueprintGenerator, +) +from library.custom_train_functions import apply_masked_loss, add_custom_train_arguments + +# from library.custom_train_functions import ( +# apply_snr_weight, +# prepare_scheduler_for_custom_training, +# scale_v_prediction_loss_like_noise_prediction, +# add_v_prediction_like_loss, +# apply_debiased_estimation, +# apply_masked_loss, +# ) + + +def train(args): + train_util.verify_training_args(args) + train_util.prepare_dataset_args(args, True) + # sdxl_train_util.verify_sdxl_training_args(args) + deepspeed_utils.prepare_deepspeed_args(args) + setup_logging(args, reset=True) + + # temporary: backward compatibility for deprecated options. remove in the future + if not args.skip_cache_check: + args.skip_cache_check = args.skip_latents_validity_check + + # assert ( + # not args.weighted_captions + # ), "weighted_captions is not supported currently / weighted_captionsは現在サポートされていません" + # assert ( + # not args.train_text_encoder or not args.cache_text_encoder_outputs + # ), "cache_text_encoder_outputs is not supported when training text encoder / text encoderを学習するときはcache_text_encoder_outputsはサポートされていません" + if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs: + logger.warning( + "cache_text_encoder_outputs_to_disk is enabled, so cache_text_encoder_outputs is also enabled / cache_text_encoder_outputs_to_diskが有効になっているため、cache_text_encoder_outputsも有効になります" + ) + args.cache_text_encoder_outputs = True + + assert not args.train_text_encoder or (args.use_t5xxl_cache_only or not args.cache_text_encoder_outputs), ( + "when training text encoder, text encoder outputs must not be cached (except for T5XXL)" + + " / text encoderの学習時はtext encoderの出力はキャッシュできません(t5xxlのみキャッシュすることは可能です)" + ) + + if args.use_t5xxl_cache_only and not args.cache_text_encoder_outputs: + logger.warning( + "use_t5xxl_cache_only is enabled, so cache_text_encoder_outputs is automatically enabled." + + " / use_t5xxl_cache_onlyが有効なため、cache_text_encoder_outputsも自動的に有効になります" + ) + args.cache_text_encoder_outputs = True + + if args.train_t5xxl: + assert ( + args.train_text_encoder + ), "when training T5XXL, text encoder (CLIP-L/G) must be trained / T5XXLを学習するときはtext encoder (CLIP-L/G)も学習する必要があります" + assert ( + not args.cache_text_encoder_outputs + ), "when training T5XXL, t5xxl output must not be cached / T5XXLを学習するときはt5xxlの出力をキャッシュできません" + + cache_latents = args.cache_latents + use_dreambooth_method = args.in_json is None + + if args.seed is not None: + set_seed(args.seed) # 乱数系列を初期化する + + # prepare caching strategy: this must be set before preparing dataset. because dataset may use this strategy for initialization. + if args.cache_latents: + latents_caching_strategy = strategy_sd3.Sd3LatentsCachingStrategy( + args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check + ) + strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy) + + # データセットを準備する + if args.dataset_class is None: + blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, args.masked_loss, True)) + if args.dataset_config is not None: + logger.info(f"Load dataset config from {args.dataset_config}") + user_config = config_util.load_user_config(args.dataset_config) + ignored = ["train_data_dir", "in_json"] + if any(getattr(args, attr) is not None for attr in ignored): + logger.warning( + "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( + ", ".join(ignored) + ) + ) + else: + if use_dreambooth_method: + logger.info("Using DreamBooth method.") + user_config = { + "datasets": [ + { + "subsets": config_util.generate_dreambooth_subsets_config_by_subdirs( + args.train_data_dir, args.reg_data_dir + ) + } + ] + } + else: + logger.info("Training with captions.") + user_config = { + "datasets": [ + { + "subsets": [ + { + "image_dir": args.train_data_dir, + "metadata_file": args.in_json, + } + ] + } + ] + } + + blueprint = blueprint_generator.generate(user_config, args) + train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) + else: + train_dataset_group = train_util.load_arbitrary_dataset(args) + val_dataset_group = None + + current_epoch = Value("i", 0) + current_step = Value("i", 0) + ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None + collator = train_util.collator_class(current_epoch, current_step, ds_for_collator) + + train_dataset_group.verify_bucket_reso_steps(8) # TODO これでいいか確認 + + if args.debug_dataset: + if args.cache_text_encoder_outputs: + strategy_base.TextEncoderOutputsCachingStrategy.set_strategy( + strategy_sd3.Sd3TextEncoderOutputsCachingStrategy( + args.cache_text_encoder_outputs_to_disk, + args.text_encoder_batch_size, + False, + False, + False, + False, + ) + ) + train_dataset_group.set_current_strategies() + train_util.debug_dataset(train_dataset_group, True) + return + if len(train_dataset_group) == 0: + logger.error( + "No data found. Please verify the metadata file and train_data_dir option. / 画像がありません。メタデータおよびtrain_data_dirオプションを確認してください。" + ) + return + + if cache_latents: + assert ( + train_dataset_group.is_latent_cacheable() + ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません" + + if args.cache_text_encoder_outputs: + assert ( + train_dataset_group.is_text_encoder_output_cacheable() + ), "when caching text encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / text encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません" + + # acceleratorを準備する + logger.info("prepare accelerator") + accelerator = train_util.prepare_accelerator(args) + + # mixed precisionに対応した型を用意しておき適宜castする + weight_dtype, save_dtype = train_util.prepare_dtype(args) + + # モデルを読み込む + + # t5xxl_dtype = weight_dtype + model_dtype = match_mixed_precision(args, weight_dtype) # None (default) or fp16/bf16 (full_xxxx) + if args.clip_l is None: + sd3_state_dict = load_safetensors( + args.pretrained_model_name_or_path, "cpu", args.disable_mmap_load_safetensors, model_dtype + ) + else: + sd3_state_dict = None + + # load tokenizer and prepare tokenize strategy + sd3_tokenize_strategy = strategy_sd3.Sd3TokenizeStrategy(args.t5xxl_max_token_length) + strategy_base.TokenizeStrategy.set_strategy(sd3_tokenize_strategy) + + # load clip_l, clip_g, t5xxl for caching text encoder outputs + # clip_l = sd3_train_utils.load_target_model("clip_l", args, sd3_state_dict, accelerator, attn_mode, clip_dtype, device_to_load) + # clip_g = sd3_train_utils.load_target_model("clip_g", args, sd3_state_dict, accelerator, attn_mode, clip_dtype, device_to_load) + clip_l = sd3_utils.load_clip_l(args.clip_l, weight_dtype, "cpu", args.disable_mmap_load_safetensors, state_dict=sd3_state_dict) + clip_g = sd3_utils.load_clip_g(args.clip_g, weight_dtype, "cpu", args.disable_mmap_load_safetensors, state_dict=sd3_state_dict) + t5xxl = sd3_utils.load_t5xxl(args.t5xxl, weight_dtype, "cpu", args.disable_mmap_load_safetensors, state_dict=sd3_state_dict) + assert clip_l is not None and clip_g is not None and t5xxl is not None, "clip_l, clip_g, t5xxl must be specified" + + # prepare text encoding strategy + text_encoding_strategy = strategy_sd3.Sd3TextEncodingStrategy( + args.apply_lg_attn_mask, args.apply_t5_attn_mask, args.clip_l_dropout_rate, args.clip_g_dropout_rate, args.t5_dropout_rate + ) + strategy_base.TextEncodingStrategy.set_strategy(text_encoding_strategy) + + # 学習を準備する:モデルを適切な状態にする + train_clip = False + train_t5xxl = False + + if args.train_text_encoder: + accelerator.print("enable text encoder training") + if args.gradient_checkpointing: + clip_l.gradient_checkpointing_enable() + clip_g.gradient_checkpointing_enable() + if args.train_t5xxl: + t5xxl.gradient_checkpointing_enable() + + lr_te1 = args.learning_rate_te1 if args.learning_rate_te1 is not None else args.learning_rate # 0 means not train + lr_te2 = args.learning_rate_te2 if args.learning_rate_te2 is not None else args.learning_rate # 0 means not train + lr_t5xxl = args.learning_rate_te3 if args.learning_rate_te3 is not None else args.learning_rate # 0 means not train + train_clip = lr_te1 != 0 or lr_te2 != 0 + train_t5xxl = lr_t5xxl != 0 and args.train_t5xxl + + clip_l.to(weight_dtype) + clip_g.to(weight_dtype) + t5xxl.to(weight_dtype) + clip_l.requires_grad_(train_clip) + clip_g.requires_grad_(train_clip) + t5xxl.requires_grad_(train_t5xxl) + else: + print("disable text encoder training") + clip_l.to(weight_dtype) + clip_g.to(weight_dtype) + t5xxl.to(weight_dtype) + clip_l.requires_grad_(False) + clip_g.requires_grad_(False) + t5xxl.requires_grad_(False) + lr_te1 = 0 + lr_te2 = 0 + lr_t5xxl = 0 + + # cache text encoder outputs + sample_prompts_te_outputs = None + if args.cache_text_encoder_outputs: + clip_l.to(accelerator.device) + clip_g.to(accelerator.device) + t5xxl.to(accelerator.device) + clip_l.eval() + clip_g.eval() + t5xxl.eval() + + text_encoder_caching_strategy = strategy_sd3.Sd3TextEncoderOutputsCachingStrategy( + args.cache_text_encoder_outputs_to_disk, + args.text_encoder_batch_size, + args.skip_cache_check, + train_clip or args.use_t5xxl_cache_only, # if clip is trained or t5xxl is cached, caching is partial + args.apply_lg_attn_mask, + args.apply_t5_attn_mask, + ) + strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(text_encoder_caching_strategy) + + with accelerator.autocast(): + train_dataset_group.new_cache_text_encoder_outputs([clip_l, clip_g, t5xxl], accelerator) + + # cache sample prompt's embeddings to free text encoder's memory + if args.sample_prompts is not None: + logger.info(f"cache Text Encoder outputs for sample prompt: {args.sample_prompts}") + prompts = train_util.load_prompts(args.sample_prompts) + sample_prompts_te_outputs = {} # key: prompt, value: text encoder outputs + with accelerator.autocast(), torch.no_grad(): + for prompt_dict in prompts: + for p in [prompt_dict.get("prompt", ""), prompt_dict.get("negative_prompt", "")]: + if p not in sample_prompts_te_outputs: + logger.info(f"cache Text Encoder outputs for prompt: {p}") + tokens_and_masks = sd3_tokenize_strategy.tokenize(p) + sample_prompts_te_outputs[p] = text_encoding_strategy.encode_tokens( + sd3_tokenize_strategy, + [clip_l, clip_g, t5xxl], + tokens_and_masks, + args.apply_lg_attn_mask, + args.apply_t5_attn_mask, + enable_dropout=False, + ) + + accelerator.wait_for_everyone() + + # now we can delete Text Encoders to free memory + if not args.use_t5xxl_cache_only: + clip_l = None + clip_g = None + t5xxl = None + + clean_memory_on_device(accelerator.device) + + # load VAE for caching latents + if sd3_state_dict is None: + logger.info(f"load state dict for MMDiT and VAE from {args.pretrained_model_name_or_path}") + sd3_state_dict = load_safetensors( + args.pretrained_model_name_or_path, "cpu", args.disable_mmap_load_safetensors, model_dtype + ) + + vae = sd3_utils.load_vae(args.vae, weight_dtype, "cpu", args.disable_mmap_load_safetensors, state_dict=sd3_state_dict) + if cache_latents: + # vae = sd3_train_utils.load_target_model("vae", args, sd3_state_dict, accelerator, attn_mode, vae_dtype, device_to_load) + vae.to(accelerator.device, dtype=weight_dtype) + vae.requires_grad_(False) + vae.eval() + + train_dataset_group.new_cache_latents(vae, accelerator) + + vae.to("cpu") # if no sampling, vae can be deleted + clean_memory_on_device(accelerator.device) + + accelerator.wait_for_everyone() + + # load MMDIT + mmdit = sd3_utils.load_mmdit(sd3_state_dict, model_dtype, "cpu") + + # attn_mode = "xformers" if args.xformers else "torch" + # assert ( + # attn_mode == "torch" + # ), f"attn_mode {attn_mode} is not supported yet. Please use `--sdpa` instead of `--xformers`. / attn_mode {attn_mode} はサポートされていません。`--xformers`の代わりに`--sdpa`を使ってください。" + + mmdit.set_pos_emb_random_crop_rate(args.pos_emb_random_crop_rate) + + # set resolutions for positional embeddings + if args.enable_scaled_pos_embed: + resolutions = train_dataset_group.get_resolutions() + latent_sizes = [round(math.sqrt(res[0] * res[1])) // 8 for res in resolutions] # 8 is stride for latent + latent_sizes = list(set(latent_sizes)) # remove duplicates + logger.info(f"Prepare scaled positional embeddings for resolutions: {resolutions}, sizes: {latent_sizes}") + mmdit.enable_scaled_pos_embed(True, latent_sizes) + + if args.gradient_checkpointing: + mmdit.enable_gradient_checkpointing() + + train_mmdit = args.learning_rate != 0 + mmdit.requires_grad_(train_mmdit) + if not train_mmdit: + mmdit.to(accelerator.device, dtype=weight_dtype) # because of mmdit will not be prepared + + # block swap + is_swapping_blocks = args.blocks_to_swap is not None and args.blocks_to_swap > 0 + if is_swapping_blocks: + # Swap blocks between CPU and GPU to reduce memory usage, in forward and backward passes. + # This idea is based on 2kpr's great work. Thank you! + logger.info(f"enable block swap: blocks_to_swap={args.blocks_to_swap}") + mmdit.enable_block_swap(args.blocks_to_swap, accelerator.device) + + if not cache_latents: + # move to accelerator device + vae.requires_grad_(False) + vae.eval() + vae.to(accelerator.device, dtype=weight_dtype) + + mmdit.requires_grad_(train_mmdit) + if not train_mmdit: + mmdit.to(accelerator.device, dtype=weight_dtype) # because of unet is not prepared + + if args.num_last_block_to_freeze: + # freeze last n blocks of MM-DIT + block_name = "x_block" + filtered_blocks = [(name, param) for name, param in mmdit.named_parameters() if block_name in name] + accelerator.print(f"filtered_blocks: {len(filtered_blocks)}") + + num_blocks_to_freeze = min(len(filtered_blocks), args.num_last_block_to_freeze) + + accelerator.print(f"freeze_blocks: {num_blocks_to_freeze}") + + start_freezing_from = max(0, len(filtered_blocks) - num_blocks_to_freeze) + + for i in range(start_freezing_from, len(filtered_blocks)): + _, param = filtered_blocks[i] + param.requires_grad = False + + training_models = [] + params_to_optimize = [] + param_names = [] + training_models.append(mmdit) + params_to_optimize.append({"params": list(filter(lambda p: p.requires_grad, mmdit.parameters())), "lr": args.learning_rate}) + param_names.append([n for n, _ in mmdit.named_parameters()]) + + if train_clip: + if lr_te1 > 0: + training_models.append(clip_l) + params_to_optimize.append({"params": list(clip_l.parameters()), "lr": args.learning_rate_te1 or args.learning_rate}) + param_names.append([n for n, _ in clip_l.named_parameters()]) + if lr_te2 > 0: + training_models.append(clip_g) + params_to_optimize.append({"params": list(clip_g.parameters()), "lr": args.learning_rate_te2 or args.learning_rate}) + param_names.append([n for n, _ in clip_g.named_parameters()]) + if train_t5xxl: + training_models.append(t5xxl) + params_to_optimize.append({"params": list(t5xxl.parameters()), "lr": args.learning_rate_te3 or args.learning_rate}) + param_names.append([n for n, _ in t5xxl.named_parameters()]) + + # calculate number of trainable parameters + n_params = 0 + for group in params_to_optimize: + for p in group["params"]: + n_params += p.numel() + + accelerator.print(f"train mmdit: {train_mmdit} , clip:{train_clip}, t5xxl:{train_t5xxl}") + accelerator.print(f"number of models: {len(training_models)}") + accelerator.print(f"number of trainable parameters: {n_params}") + + # 学習に必要なクラスを準備する + accelerator.print("prepare optimizer, data loader etc.") + + if args.blockwise_fused_optimizers: + # fused backward pass: https://pytorch.org/tutorials/intermediate/optimizer_step_in_backward_tutorial.html + # Instead of creating an optimizer for all parameters as in the tutorial, we create an optimizer for each block of parameters. + # This balances memory usage and management complexity. + + # split params into groups for mmdit. clip_l, clip_g, t5xxl are in each group + grouped_params = [] + param_group = {} + group = params_to_optimize[0] + named_parameters = list(mmdit.named_parameters()) + assert len(named_parameters) == len(group["params"]), "number of parameters does not match" + for p, np in zip(group["params"], named_parameters): + # determine target layer and block index for each parameter + block_type = "other" # joint or other + if np[0].startswith("joint_blocks"): + block_idx = int(np[0].split(".")[1]) + block_type = "joint" + else: + block_idx = -1 + + param_group_key = (block_type, block_idx) + if param_group_key not in param_group: + param_group[param_group_key] = [] + param_group[param_group_key].append(p) + + block_types_and_indices = [] + for param_group_key, param_group in param_group.items(): + block_types_and_indices.append(param_group_key) + grouped_params.append({"params": param_group, "lr": args.learning_rate}) + + num_params = 0 + for p in param_group: + num_params += p.numel() + accelerator.print(f"block {param_group_key}: {num_params} parameters") + + grouped_params.extend(params_to_optimize[1:]) # add clip_l, clip_g, t5xxl if they are trained + + # prepare optimizers for each group + optimizers = [] + for group in grouped_params: + _, _, optimizer = train_util.get_optimizer(args, trainable_params=[group]) + optimizers.append(optimizer) + optimizer = optimizers[0] # avoid error in the following code + + logger.info(f"using {len(optimizers)} optimizers for blockwise fused optimizers") + + if train_util.is_schedulefree_optimizer(optimizers[0], args): + raise ValueError("Schedule-free optimizer is not supported with blockwise fused optimizers") + optimizer_train_fn = lambda: None # dummy function + optimizer_eval_fn = lambda: None # dummy function + else: + _, _, optimizer = train_util.get_optimizer(args, trainable_params=params_to_optimize) + optimizer_train_fn, optimizer_eval_fn = train_util.get_optimizer_train_eval_fn(optimizer, args) + + # prepare dataloader + # strategies are set here because they cannot be referenced in another process. Copy them with the dataset + # some strategies can be None + train_dataset_group.set_current_strategies() + + # DataLoaderのプロセス数:0 は persistent_workers が使えないので注意 + n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers + train_dataloader = torch.utils.data.DataLoader( + train_dataset_group, + batch_size=1, + shuffle=True, + collate_fn=collator, + num_workers=n_workers, + persistent_workers=args.persistent_data_loader_workers, + ) + + # 学習ステップ数を計算する + if args.max_train_epochs is not None: + args.max_train_steps = args.max_train_epochs * math.ceil( + len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps + ) + accelerator.print( + f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}" + ) + + # データセット側にも学習ステップを送信 + train_dataset_group.set_max_train_steps(args.max_train_steps) + + # lr schedulerを用意する + if args.blockwise_fused_optimizers: + # prepare lr schedulers for each optimizer + lr_schedulers = [train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) for optimizer in optimizers] + lr_scheduler = lr_schedulers[0] # avoid error in the following code + else: + lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) + + # 実験的機能:勾配も含めたfp16/bf16学習を行う モデル全体をfp16/bf16にする + if args.full_fp16: + assert ( + args.mixed_precision == "fp16" + ), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。" + accelerator.print("enable full fp16 training.") + mmdit.to(weight_dtype) + if clip_l is not None: + clip_l.to(weight_dtype) + if clip_g is not None: + clip_g.to(weight_dtype) + if t5xxl is not None: + t5xxl.to(weight_dtype) + elif args.full_bf16: + assert ( + args.mixed_precision == "bf16" + ), "full_bf16 requires mixed precision='bf16' / full_bf16を使う場合はmixed_precision='bf16'を指定してください。" + accelerator.print("enable full bf16 training.") + mmdit.to(weight_dtype) + if clip_l is not None: + clip_l.to(weight_dtype) + if clip_g is not None: + clip_g.to(weight_dtype) + if t5xxl is not None: + t5xxl.to(weight_dtype) + + # TODO check if this is necessary. SD3 uses pool for clip_l and clip_g + # # freeze last layer and final_layer_norm in te1 since we use the output of the penultimate layer + # if train_clip_l: + # clip_l.text_model.encoder.layers[-1].requires_grad_(False) + # clip_l.text_model.final_layer_norm.requires_grad_(False) + + # move Text Encoders to GPU if not caching outputs + if not args.cache_text_encoder_outputs: + # make sure Text Encoders are on GPU + # TODO support CPU for text encoders + clip_l.to(accelerator.device) + clip_g.to(accelerator.device) + if t5xxl is not None: + t5xxl.to(accelerator.device) + + clean_memory_on_device(accelerator.device) + + if args.deepspeed: + ds_model = deepspeed_utils.prepare_deepspeed_model( + args, mmdit=mmdit, clip_l=clip_l if train_clip else None, clip_g=clip_g if train_clip else None + ) + # most of ZeRO stage uses optimizer partitioning, so we have to prepare optimizer and ds_model at the same time. # pull/1139#issuecomment-1986790007 + ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + ds_model, optimizer, train_dataloader, lr_scheduler + ) + training_models = [ds_model] + + else: + # acceleratorがなんかよろしくやってくれるらしい + if train_mmdit: + mmdit = accelerator.prepare(mmdit, device_placement=[not is_swapping_blocks]) + if is_swapping_blocks: + accelerator.unwrap_model(mmdit).move_to_device_except_swap_blocks(accelerator.device) # reduce peak memory usage + if train_clip: + clip_l = accelerator.prepare(clip_l) + clip_g = accelerator.prepare(clip_g) + if train_t5xxl: + t5xxl = accelerator.prepare(t5xxl) + optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler) + + # 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする + if args.full_fp16: + # During deepseed training, accelerate not handles fp16/bf16|mixed precision directly via scaler. Let deepspeed engine do. + # -> But we think it's ok to patch accelerator even if deepspeed is enabled. + train_util.patch_accelerator_for_fp16_training(accelerator) + + # resumeする + train_util.resume_from_local_or_hf_if_specified(accelerator, args) + + if args.fused_backward_pass: + # use fused optimizer for backward pass: other optimizers will be supported in the future + import library.adafactor_fused + + library.adafactor_fused.patch_adafactor_fused(optimizer) + + for param_group, param_name_group in zip(optimizer.param_groups, param_names): + for parameter, param_name in zip(param_group["params"], param_name_group): + if parameter.requires_grad: + + def create_grad_hook(p_name, p_group): + def grad_hook(tensor: torch.Tensor): + if accelerator.sync_gradients and args.max_grad_norm != 0.0: + accelerator.clip_grad_norm_(tensor, args.max_grad_norm) + optimizer.step_param(tensor, p_group) + tensor.grad = None + + return grad_hook + + parameter.register_post_accumulate_grad_hook(create_grad_hook(param_name, param_group)) + + elif args.blockwise_fused_optimizers: + # prepare for additional optimizers and lr schedulers + for i in range(1, len(optimizers)): + optimizers[i] = accelerator.prepare(optimizers[i]) + lr_schedulers[i] = accelerator.prepare(lr_schedulers[i]) + + # counters are used to determine when to step the optimizer + global optimizer_hooked_count + global num_parameters_per_group + global parameter_optimizer_map + + optimizer_hooked_count = {} + num_parameters_per_group = [0] * len(optimizers) + parameter_optimizer_map = {} + + for opt_idx, optimizer in enumerate(optimizers): + for param_group in optimizer.param_groups: + for parameter in param_group["params"]: + if parameter.requires_grad: + + def grad_hook(parameter: torch.Tensor): + if accelerator.sync_gradients and args.max_grad_norm != 0.0: + accelerator.clip_grad_norm_(parameter, args.max_grad_norm) + + i = parameter_optimizer_map[parameter] + optimizer_hooked_count[i] += 1 + if optimizer_hooked_count[i] == num_parameters_per_group[i]: + optimizers[i].step() + optimizers[i].zero_grad(set_to_none=True) + + parameter.register_post_accumulate_grad_hook(grad_hook) + parameter_optimizer_map[parameter] = opt_idx + num_parameters_per_group[opt_idx] += 1 + + # epoch数を計算する + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0): + args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1 + + # 学習する + # total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + accelerator.print("running training / 学習開始") + accelerator.print(f" num examples / サンプル数: {train_dataset_group.num_train_images}") + accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}") + accelerator.print(f" num epochs / epoch数: {num_train_epochs}") + accelerator.print( + f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}" + ) + # accelerator.print( + # f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}" + # ) + accelerator.print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}") + accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}") + + progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps") + global_step = 0 + + # only used to get timesteps, etc. TODO manage timesteps etc. separately + dummy_scheduler = sd3_train_utils.FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=3.0) + + if accelerator.is_main_process: + init_kwargs = {} + if args.wandb_run_name: + init_kwargs["wandb"] = {"name": args.wandb_run_name} + if args.log_tracker_config is not None: + init_kwargs = toml.load(args.log_tracker_config) + accelerator.init_trackers( + "finetuning" if args.log_tracker_name is None else args.log_tracker_name, + config=train_util.get_sanitized_config_or_none(args), + init_kwargs=init_kwargs, + ) + + if is_swapping_blocks: + accelerator.unwrap_model(mmdit).prepare_block_swap_before_forward() + + # For --sample_at_first + optimizer_eval_fn() + sd3_train_utils.sample_images(accelerator, args, 0, global_step, mmdit, vae, [clip_l, clip_g, t5xxl], sample_prompts_te_outputs) + optimizer_train_fn() + if len(accelerator.trackers) > 0: + # log empty object to commit the sample images to wandb + accelerator.log({}, step=0) + + # show model device and dtype + logger.info( + f"mmdit device: {accelerator.unwrap_model(mmdit).device}, dtype: {accelerator.unwrap_model(mmdit).dtype}" + if mmdit + else "mmdit is None" + ) + logger.info( + f"clip_l device: {accelerator.unwrap_model(clip_l).device}, dtype: {accelerator.unwrap_model(clip_l).dtype}" + if clip_l + else "clip_l is None" + ) + logger.info( + f"clip_g device: {accelerator.unwrap_model(clip_g).device}, dtype: {accelerator.unwrap_model(clip_g).dtype}" + if clip_g + else "clip_g is None" + ) + logger.info( + f"t5xxl device: {accelerator.unwrap_model(t5xxl).device}, dtype: {accelerator.unwrap_model(t5xxl).dtype}" + if t5xxl + else "t5xxl is None" + ) + logger.info( + f"vae device: {accelerator.unwrap_model(vae).device}, dtype: {accelerator.unwrap_model(vae).dtype}" + if vae is not None + else "vae is None" + ) + + loss_recorder = train_util.LossRecorder() + epoch = 0 # avoid error when max_train_steps is 0 + for epoch in range(num_train_epochs): + accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}") + current_epoch.value = epoch + 1 + + for m in training_models: + m.train() + + for step, batch in enumerate(train_dataloader): + current_step.value = global_step + + if args.blockwise_fused_optimizers: + optimizer_hooked_count = {i: 0 for i in range(len(optimizers))} # reset counter for each step + + with accelerator.accumulate(*training_models): + if "latents" in batch and batch["latents"] is not None: + latents = batch["latents"].to(accelerator.device, dtype=weight_dtype) + else: + with torch.no_grad(): + # encode images to latents. images are [-1, 1] + latents = vae.encode(batch["images"].to(vae.device, dtype=vae.dtype)).to( + accelerator.device, dtype=weight_dtype + ) + + # NaNが含まれていれば警告を表示し0に置き換える + if torch.any(torch.isnan(latents)): + accelerator.print("NaN found in latents, replacing with zeros") + latents = torch.nan_to_num(latents, 0, out=latents) + + # latents = latents * sdxl_model_util.VAE_SCALE_FACTOR + latents = sd3_models.SDVAE.process_in(latents) + + text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None) + if text_encoder_outputs_list is not None: + text_encoder_outputs_list = text_encoding_strategy.drop_cached_text_encoder_outputs(*text_encoder_outputs_list) + lg_out, t5_out, lg_pooled, l_attn_mask, g_attn_mask, t5_attn_mask = text_encoder_outputs_list + if args.use_t5xxl_cache_only: + lg_out = None + lg_pooled = None + else: + lg_out = None + t5_out = None + lg_pooled = None + l_attn_mask = None + g_attn_mask = None + t5_attn_mask = None + + if lg_out is None: + # not cached or training, so get from text encoders + input_ids_clip_l, input_ids_clip_g, _, l_attn_mask, g_attn_mask, _ = batch["input_ids_list"] + with torch.set_grad_enabled(train_clip): + # TODO support weighted captions + # text models in sd3_models require "cpu" for input_ids + input_ids_clip_l = input_ids_clip_l.to("cpu") + input_ids_clip_g = input_ids_clip_g.to("cpu") + lg_out, _, lg_pooled, l_attn_mask, g_attn_mask, _ = text_encoding_strategy.encode_tokens( + sd3_tokenize_strategy, + [clip_l, clip_g, None], + [input_ids_clip_l, input_ids_clip_g, None, l_attn_mask, g_attn_mask, None], + ) + + if t5_out is None: + _, _, input_ids_t5xxl, _, _, t5_attn_mask = batch["input_ids_list"] + with torch.set_grad_enabled(train_t5xxl): + input_ids_t5xxl = input_ids_t5xxl.to("cpu") + _, t5_out, _, _, _, t5_attn_mask = text_encoding_strategy.encode_tokens( + sd3_tokenize_strategy, [None, None, t5xxl], [None, None, input_ids_t5xxl, None, None, t5_attn_mask] + ) + + context, lg_pooled = text_encoding_strategy.concat_encodings(lg_out, t5_out, lg_pooled) + + # TODO support some features for noise implemented in get_noise_noisy_latents_and_timesteps + + # Sample noise that we'll add to the latents + noise = torch.randn_like(latents) + # bsz = latents.shape[0] + + # get noisy model input and timesteps + noisy_model_input, timesteps, sigmas = sd3_train_utils.get_noisy_model_input_and_timesteps( + args, latents, noise, accelerator.device, weight_dtype + ) + + # debug: NaN check for all inputs + if torch.any(torch.isnan(noisy_model_input)): + accelerator.print("NaN found in noisy_model_input, replacing with zeros") + noisy_model_input = torch.nan_to_num(noisy_model_input, 0, out=noisy_model_input) + if torch.any(torch.isnan(context)): + accelerator.print("NaN found in context, replacing with zeros") + context = torch.nan_to_num(context, 0, out=context) + if torch.any(torch.isnan(lg_pooled)): + accelerator.print("NaN found in pool, replacing with zeros") + lg_pooled = torch.nan_to_num(lg_pooled, 0, out=lg_pooled) + + # call model + with accelerator.autocast(): + # TODO support attention mask + model_pred = mmdit(noisy_model_input, timesteps, context=context, y=lg_pooled) + + # Follow: Section 5 of https://arxiv.org/abs/2206.00364. + # Preconditioning of the model outputs. + model_pred = model_pred * (-sigmas) + noisy_model_input + + # these weighting schemes use a uniform timestep sampling + # and instead post-weight the loss + weighting = sd3_train_utils.compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas) + + # flow matching loss + target = latents + + # # Compute regular loss. TODO simplify this + # loss = torch.mean( + # (weighting.float() * (model_pred.float() - target.float()) ** 2).reshape(target.shape[0], -1), + # 1, + # ) + # calculate loss + huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, dummy_scheduler) + loss = train_util.conditional_loss(model_pred.float(), target.float(), args.loss_type, "none", huber_c) + if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None): + loss = apply_masked_loss(loss, batch) + loss = loss.mean([1, 2, 3]) + + if weighting is not None: + loss = loss * weighting + + loss_weights = batch["loss_weights"] # 各sampleごとのweight + loss = loss * loss_weights + loss = loss.mean() + + accelerator.backward(loss) + + if not (args.fused_backward_pass or args.blockwise_fused_optimizers): + if accelerator.sync_gradients and args.max_grad_norm != 0.0: + params_to_clip = [] + for m in training_models: + params_to_clip.extend(m.parameters()) + accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) + + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad(set_to_none=True) + else: + # optimizer.step() and optimizer.zero_grad() are called in the optimizer hook + lr_scheduler.step() + if args.blockwise_fused_optimizers: + for i in range(1, len(optimizers)): + lr_schedulers[i].step() + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + global_step += 1 + + optimizer_eval_fn() + sd3_train_utils.sample_images( + accelerator, args, None, global_step, mmdit, vae, [clip_l, clip_g, t5xxl], sample_prompts_te_outputs + ) + + # 指定ステップごとにモデルを保存 + if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0: + accelerator.wait_for_everyone() + if accelerator.is_main_process: + sd3_train_utils.save_sd3_model_on_epoch_end_or_stepwise( + args, + False, + accelerator, + save_dtype, + epoch, + num_train_epochs, + global_step, + accelerator.unwrap_model(clip_l) if train_clip else None, + accelerator.unwrap_model(clip_g) if train_clip else None, + accelerator.unwrap_model(t5xxl) if train_t5xxl else None, + accelerator.unwrap_model(mmdit) if train_mmdit else None, + vae, + ) + optimizer_train_fn() + + current_loss = loss.detach().item() # 平均なのでbatch sizeは関係ないはず + if len(accelerator.trackers) > 0: + logs = {"loss": current_loss} + train_util.append_lr_to_logs(logs, lr_scheduler, args.optimizer_type, including_unet=train_mmdit) + + accelerator.log(logs, step=global_step) + + loss_recorder.add(epoch=epoch, step=step, loss=current_loss) + avr_loss: float = loss_recorder.moving_average + logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]} + progress_bar.set_postfix(**logs) + + if global_step >= args.max_train_steps: + break + + if len(accelerator.trackers) > 0: + logs = {"loss/epoch": loss_recorder.moving_average} + accelerator.log(logs, step=epoch + 1) + + accelerator.wait_for_everyone() + + optimizer_eval_fn() + if args.save_every_n_epochs is not None: + if accelerator.is_main_process: + sd3_train_utils.save_sd3_model_on_epoch_end_or_stepwise( + args, + True, + accelerator, + save_dtype, + epoch, + num_train_epochs, + global_step, + accelerator.unwrap_model(clip_l) if train_clip else None, + accelerator.unwrap_model(clip_g) if train_clip else None, + accelerator.unwrap_model(t5xxl) if train_t5xxl else None, + accelerator.unwrap_model(mmdit) if train_mmdit else None, + vae, + ) + + sd3_train_utils.sample_images( + accelerator, args, epoch + 1, global_step, mmdit, vae, [clip_l, clip_g, t5xxl], sample_prompts_te_outputs + ) + + is_main_process = accelerator.is_main_process + # if is_main_process: + mmdit = accelerator.unwrap_model(mmdit) + clip_l = accelerator.unwrap_model(clip_l) + clip_g = accelerator.unwrap_model(clip_g) + if t5xxl is not None: + t5xxl = accelerator.unwrap_model(t5xxl) + + accelerator.end_training() + optimizer_eval_fn() + + if args.save_state or args.save_state_on_train_end: + train_util.save_state_on_train_end(args, accelerator) + + del accelerator # この後メモリを使うのでこれは消す + + if is_main_process: + sd3_train_utils.save_sd3_model_on_train_end( + args, + save_dtype, + epoch, + global_step, + clip_l if train_clip else None, + clip_g if train_clip else None, + t5xxl if train_t5xxl else None, + mmdit if train_mmdit else None, + vae, + ) + logger.info("model saved.") + + +def setup_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser() + + add_logging_arguments(parser) + train_util.add_sd_models_arguments(parser) + sai_model_spec.add_model_spec_arguments(parser) + train_util.add_dataset_arguments(parser, True, True, True) + train_util.add_training_arguments(parser, False) + train_util.add_masked_loss_arguments(parser) + deepspeed_utils.add_deepspeed_arguments(parser) + train_util.add_sd_saving_arguments(parser) + train_util.add_optimizer_arguments(parser) + config_util.add_config_arguments(parser) + add_custom_train_arguments(parser) + train_util.add_dit_training_arguments(parser) + sd3_train_utils.add_sd3_training_arguments(parser) + + parser.add_argument( + "--train_text_encoder", action="store_true", help="train text encoder (CLIP-L and G) / text encoderも学習する" + ) + parser.add_argument("--train_t5xxl", action="store_true", help="train T5-XXL / T5-XXLも学習する") + parser.add_argument( + "--use_t5xxl_cache_only", action="store_true", help="cache T5-XXL outputs only / T5-XXLの出力のみキャッシュする" + ) + + parser.add_argument( + "--learning_rate_te1", + type=float, + default=None, + help="learning rate for text encoder 1 (ViT-L) / text encoder 1 (ViT-L)の学習率", + ) + parser.add_argument( + "--learning_rate_te2", + type=float, + default=None, + help="learning rate for text encoder 2 (BiG-G) / text encoder 2 (BiG-G)の学習率", + ) + parser.add_argument( + "--learning_rate_te3", + type=float, + default=None, + help="learning rate for text encoder 3 (T5-XXL) / text encoder 3 (T5-XXL)の学習率", + ) + + # parser.add_argument( + # "--diffusers_xformers", action="store_true", help="use xformers by diffusers / Diffusersでxformersを使用する" + # ) + # parser.add_argument( + # "--no_half_vae", + # action="store_true", + # help="do not use fp16/bf16 VAE in mixed precision (use float VAE) / mixed precisionでも fp16/bf16 VAEを使わずfloat VAEを使う", + # ) + # parser.add_argument( + # "--block_lr", + # type=str, + # default=None, + # help=f"learning rates for each block of U-Net, comma-separated, {UNET_NUM_BLOCKS_FOR_BLOCK_LR} values / " + # + f"U-Netの各ブロックの学習率、カンマ区切り、{UNET_NUM_BLOCKS_FOR_BLOCK_LR}個の値", + # ) + parser.add_argument( + "--blockwise_fused_optimizers", + action="store_true", + help="enable blockwise optimizers for fused backward pass and optimizer step / fused backward passとoptimizer step のためブロック単位のoptimizerを有効にする", + ) + parser.add_argument( + "--fused_optimizer_groups", + type=int, + default=None, + help="[DOES NOT WORK] number of optimizer groups for fused backward pass and optimizer step / fused backward passとoptimizer stepのためのoptimizerグループ数", + ) + parser.add_argument( + "--skip_latents_validity_check", + action="store_true", + help="[Deprecated] use 'skip_cache_check' instead / 代わりに 'skip_cache_check' を使用してください", + ) + parser.add_argument( + "--num_last_block_to_freeze", + type=int, + default=None, + help="freeze last n blocks of MM-DIT / MM-DITの最後のnブロックを凍結する", + ) + return parser + + +if __name__ == "__main__": + parser = setup_parser() + + args = parser.parse_args() + train_util.verify_command_line_training_args(args) + args = train_util.read_config_from_file(args, parser) + + train(args) diff --git a/sd3_train_network.py b/sd3_train_network.py new file mode 100644 index 000000000..c9b06a38a --- /dev/null +++ b/sd3_train_network.py @@ -0,0 +1,497 @@ +import argparse +import copy +import math +import random +from typing import Any, Optional, Union + +import torch +from accelerate import Accelerator +from library import sd3_models, strategy_sd3, utils +from library.device_utils import init_ipex, clean_memory_on_device +from library.safetensors_utils import load_safetensors + +init_ipex() + +from library import flux_models, flux_train_utils, flux_utils, sd3_train_utils, sd3_utils, strategy_base, strategy_sd3, train_util +import train_network +from library.utils import setup_logging + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + + +class Sd3NetworkTrainer(train_network.NetworkTrainer): + def __init__(self): + super().__init__() + self.sample_prompts_te_outputs = None + + def assert_extra_args( + self, + args, + train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset], + val_dataset_group: Optional[train_util.DatasetGroup], + ): + # super().assert_extra_args(args, train_dataset_group) + # sdxl_train_util.verify_sdxl_training_args(args) + + if args.fp8_base_unet: + args.fp8_base = True # if fp8_base_unet is enabled, fp8_base is also enabled for SD3 + + if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs: + logger.warning( + "cache_text_encoder_outputs_to_disk is enabled, so cache_text_encoder_outputs is also enabled / cache_text_encoder_outputs_to_diskが有効になっているため、cache_text_encoder_outputsも有効になります" + ) + args.cache_text_encoder_outputs = True + + if args.cache_text_encoder_outputs: + assert ( + train_dataset_group.is_text_encoder_output_cacheable() + ), "when caching Text Encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / Text Encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません" + + # prepare CLIP-L/CLIP-G/T5XXL training flags + self.train_clip = not args.network_train_unet_only + self.train_t5xxl = False # default is False even if args.network_train_unet_only is False + + if args.max_token_length is not None: + logger.warning("max_token_length is not used in Flux training / max_token_lengthはFluxのトレーニングでは使用されません") + + assert ( + args.blocks_to_swap is None or args.blocks_to_swap == 0 + ) or not args.cpu_offload_checkpointing, "blocks_to_swap is not supported with cpu_offload_checkpointing / blocks_to_swapはcpu_offload_checkpointingと併用できません" + + train_dataset_group.verify_bucket_reso_steps(32) # TODO check this + if val_dataset_group is not None: + val_dataset_group.verify_bucket_reso_steps(32) # TODO check this + + # enumerate resolutions from dataset for positional embeddings + resolutions = train_dataset_group.get_resolutions() + if val_dataset_group is not None: + resolutions = resolutions + val_dataset_group.get_resolutions() + self.resolutions = resolutions + + def load_target_model(self, args, weight_dtype, accelerator): + # currently offload to cpu for some models + + # if the file is fp8 and we are using fp8_base, we can load it as is (fp8) + loading_dtype = None if args.fp8_base else weight_dtype + + # if we load to cpu, flux.to(fp8) takes a long time, so we should load to gpu in future + state_dict = load_safetensors( + args.pretrained_model_name_or_path, "cpu", disable_mmap=args.disable_mmap_load_safetensors, dtype=loading_dtype + ) + mmdit = sd3_utils.load_mmdit(state_dict, loading_dtype, "cpu") + self.model_type = mmdit.model_type + mmdit.set_pos_emb_random_crop_rate(args.pos_emb_random_crop_rate) + + # set resolutions for positional embeddings + if args.enable_scaled_pos_embed: + latent_sizes = [round(math.sqrt(res[0] * res[1])) // 8 for res in self.resolutions] # 8 is stride for latent + latent_sizes = list(set(latent_sizes)) # remove duplicates + logger.info(f"Prepare scaled positional embeddings for resolutions: {self.resolutions}, sizes: {latent_sizes}") + mmdit.enable_scaled_pos_embed(True, latent_sizes) + + if args.fp8_base: + # check dtype of model + if mmdit.dtype == torch.float8_e4m3fnuz or mmdit.dtype == torch.float8_e5m2 or mmdit.dtype == torch.float8_e5m2fnuz: + raise ValueError(f"Unsupported fp8 model dtype: {mmdit.dtype}") + elif mmdit.dtype == torch.float8_e4m3fn: + logger.info("Loaded fp8 SD3 model") + else: + logger.info( + "Cast SD3 model to fp8. This may take a while. You can reduce the time by using fp8 checkpoint." + " / SD3モデルをfp8に変換しています。これには時間がかかる場合があります。fp8チェックポイントを使用することで時間を短縮できます。" + ) + mmdit.to(torch.float8_e4m3fn) + self.is_swapping_blocks = args.blocks_to_swap is not None and args.blocks_to_swap > 0 + if self.is_swapping_blocks: + # Swap blocks between CPU and GPU to reduce memory usage, in forward and backward passes. + logger.info(f"enable block swap: blocks_to_swap={args.blocks_to_swap}") + mmdit.enable_block_swap(args.blocks_to_swap, accelerator.device) + + clip_l = sd3_utils.load_clip_l( + args.clip_l, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors, state_dict=state_dict + ) + clip_l.eval() + clip_g = sd3_utils.load_clip_g( + args.clip_g, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors, state_dict=state_dict + ) + clip_g.eval() + + # if the file is fp8 and we are using fp8_base (not unet), we can load it as is (fp8) + if args.fp8_base and not args.fp8_base_unet: + loading_dtype = None # as is + else: + loading_dtype = weight_dtype + + # loading t5xxl to cpu takes a long time, so we should load to gpu in future + t5xxl = sd3_utils.load_t5xxl( + args.t5xxl, loading_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors, state_dict=state_dict + ) + t5xxl.eval() + if args.fp8_base and not args.fp8_base_unet: + # check dtype of model + if t5xxl.dtype == torch.float8_e4m3fnuz or t5xxl.dtype == torch.float8_e5m2 or t5xxl.dtype == torch.float8_e5m2fnuz: + raise ValueError(f"Unsupported fp8 model dtype: {t5xxl.dtype}") + elif t5xxl.dtype == torch.float8_e4m3fn: + logger.info("Loaded fp8 T5XXL model") + + vae = sd3_utils.load_vae( + args.vae, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors, state_dict=state_dict + ) + + return mmdit.model_type, [clip_l, clip_g, t5xxl], vae, mmdit + + def get_tokenize_strategy(self, args): + logger.info(f"t5xxl_max_token_length: {args.t5xxl_max_token_length}") + return strategy_sd3.Sd3TokenizeStrategy(args.t5xxl_max_token_length, args.tokenizer_cache_dir) + + def get_tokenizers(self, tokenize_strategy: strategy_sd3.Sd3TokenizeStrategy): + return [tokenize_strategy.clip_l, tokenize_strategy.clip_g, tokenize_strategy.t5xxl] + + def get_latents_caching_strategy(self, args): + latents_caching_strategy = strategy_sd3.Sd3LatentsCachingStrategy( + args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check + ) + return latents_caching_strategy + + def get_text_encoding_strategy(self, args): + return strategy_sd3.Sd3TextEncodingStrategy( + args.apply_lg_attn_mask, + args.apply_t5_attn_mask, + args.clip_l_dropout_rate, + args.clip_g_dropout_rate, + args.t5_dropout_rate, + ) + + def post_process_network(self, args, accelerator, network, text_encoders, unet): + # check t5xxl is trained or not + self.train_t5xxl = network.train_t5xxl + + if self.train_t5xxl and args.cache_text_encoder_outputs: + raise ValueError( + "T5XXL is trained, so cache_text_encoder_outputs cannot be used / T5XXL学習時はcache_text_encoder_outputsは使用できません" + ) + + def get_models_for_text_encoding(self, args, accelerator, text_encoders): + if args.cache_text_encoder_outputs: + if self.train_clip and not self.train_t5xxl: + return text_encoders[0:2] + [None] # only CLIP-L/CLIP-G is needed for encoding because T5XXL is cached + else: + return None # no text encoders are needed for encoding because both are cached + else: + return text_encoders # CLIP-L, CLIP-G and T5XXL are needed for encoding + + def get_text_encoders_train_flags(self, args, text_encoders): + return [self.train_clip, self.train_clip, self.train_t5xxl] + + def get_text_encoder_outputs_caching_strategy(self, args): + if args.cache_text_encoder_outputs: + # if the text encoders is trained, we need tokenization, so is_partial is True + return strategy_sd3.Sd3TextEncoderOutputsCachingStrategy( + args.cache_text_encoder_outputs_to_disk, + args.text_encoder_batch_size, + args.skip_cache_check, + is_partial=self.train_clip or self.train_t5xxl, + apply_lg_attn_mask=args.apply_lg_attn_mask, + apply_t5_attn_mask=args.apply_t5_attn_mask, + ) + else: + return None + + def cache_text_encoder_outputs_if_needed( + self, args, accelerator: Accelerator, unet, vae, text_encoders, dataset: train_util.DatasetGroup, weight_dtype + ): + if args.cache_text_encoder_outputs: + if not args.lowram: + # メモリ消費を減らす + logger.info("move vae and unet to cpu to save memory") + org_vae_device = vae.device + org_unet_device = unet.device + vae.to("cpu") + unet.to("cpu") + clean_memory_on_device(accelerator.device) + + # When TE is not be trained, it will not be prepared so we need to use explicit autocast + logger.info("move text encoders to gpu") + text_encoders[0].to(accelerator.device, dtype=weight_dtype) # always not fp8 + text_encoders[1].to(accelerator.device, dtype=weight_dtype) # always not fp8 + text_encoders[2].to(accelerator.device) # may be fp8 + + if text_encoders[2].dtype == torch.float8_e4m3fn: + # if we load fp8 weights, the model is already fp8, so we use it as is + self.prepare_text_encoder_fp8(2, text_encoders[2], text_encoders[2].dtype, weight_dtype) + else: + # otherwise, we need to convert it to target dtype + text_encoders[2].to(weight_dtype) + + with accelerator.autocast(): + dataset.new_cache_text_encoder_outputs(text_encoders, accelerator) + + # cache sample prompts + if args.sample_prompts is not None: + logger.info(f"cache Text Encoder outputs for sample prompt: {args.sample_prompts}") + + tokenize_strategy: strategy_sd3.Sd3TokenizeStrategy = strategy_base.TokenizeStrategy.get_strategy() + text_encoding_strategy: strategy_sd3.Sd3TextEncodingStrategy = strategy_base.TextEncodingStrategy.get_strategy() + + prompts = train_util.load_prompts(args.sample_prompts) + sample_prompts_te_outputs = {} # key: prompt, value: text encoder outputs + with accelerator.autocast(), torch.no_grad(): + for prompt_dict in prompts: + for p in [prompt_dict.get("prompt", ""), prompt_dict.get("negative_prompt", "")]: + if p not in sample_prompts_te_outputs: + logger.info(f"cache Text Encoder outputs for prompt: {p}") + tokens_and_masks = tokenize_strategy.tokenize(p) + sample_prompts_te_outputs[p] = text_encoding_strategy.encode_tokens( + tokenize_strategy, + text_encoders, + tokens_and_masks, + args.apply_lg_attn_mask, + args.apply_t5_attn_mask, + ) + self.sample_prompts_te_outputs = sample_prompts_te_outputs + + accelerator.wait_for_everyone() + + # move back to cpu + if not self.is_train_text_encoder(args): + logger.info("move CLIP-L back to cpu") + text_encoders[0].to("cpu") + logger.info("move CLIP-G back to cpu") + text_encoders[1].to("cpu") + logger.info("move t5XXL back to cpu") + text_encoders[2].to("cpu") + clean_memory_on_device(accelerator.device) + + if not args.lowram: + logger.info("move vae and unet back to original device") + vae.to(org_vae_device) + unet.to(org_unet_device) + else: + # Text Encoderから毎回出力を取得するので、GPUに乗せておく + text_encoders[0].to(accelerator.device, dtype=weight_dtype) + text_encoders[1].to(accelerator.device, dtype=weight_dtype) + text_encoders[2].to(accelerator.device) + + # def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_conds, batch, weight_dtype): + # noisy_latents = noisy_latents.to(weight_dtype) # TODO check why noisy_latents is not weight_dtype + + # # get size embeddings + # orig_size = batch["original_sizes_hw"] + # crop_size = batch["crop_top_lefts"] + # target_size = batch["target_sizes_hw"] + # embs = sdxl_train_util.get_size_embeddings(orig_size, crop_size, target_size, accelerator.device).to(weight_dtype) + + # # concat embeddings + # encoder_hidden_states1, encoder_hidden_states2, pool2 = text_conds + # vector_embedding = torch.cat([pool2, embs], dim=1).to(weight_dtype) + # text_embedding = torch.cat([encoder_hidden_states1, encoder_hidden_states2], dim=2).to(weight_dtype) + + # noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_embedding) + # return noise_pred + + def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, mmdit): + text_encoders = text_encoder # for compatibility + text_encoders = self.get_models_for_text_encoding(args, accelerator, text_encoders) + + sd3_train_utils.sample_images( + accelerator, args, epoch, global_step, mmdit, vae, text_encoders, self.sample_prompts_te_outputs + ) + + def get_noise_scheduler(self, args: argparse.Namespace, device: torch.device) -> Any: + # this scheduler is not used in training, but used to get num_train_timesteps etc. + noise_scheduler = sd3_train_utils.FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=args.training_shift) + return noise_scheduler + + def encode_images_to_latents(self, args, vae, images): + return vae.encode(images) + + def shift_scale_latents(self, args, latents): + return sd3_models.SDVAE.process_in(latents) + + def get_noise_pred_and_target( + self, + args, + accelerator, + noise_scheduler, + latents, + batch, + text_encoder_conds, + unet: flux_models.Flux, + network, + weight_dtype, + train_unet, + is_train=True, + ): + # Sample noise that we'll add to the latents + noise = torch.randn_like(latents) + + # get noisy model input and timesteps + noisy_model_input, timesteps, sigmas = sd3_train_utils.get_noisy_model_input_and_timesteps( + args, latents, noise, accelerator.device, weight_dtype + ) + + # ensure the hidden state will require grad + if args.gradient_checkpointing: + noisy_model_input.requires_grad_(True) + for t in text_encoder_conds: + if t is not None and t.dtype.is_floating_point: + t.requires_grad_(True) + + # Predict the noise residual + lg_out, t5_out, lg_pooled, l_attn_mask, g_attn_mask, t5_attn_mask = text_encoder_conds + text_encoding_strategy = strategy_base.TextEncodingStrategy.get_strategy() + context, lg_pooled = text_encoding_strategy.concat_encodings(lg_out, t5_out, lg_pooled) + if not args.apply_lg_attn_mask: + l_attn_mask = None + g_attn_mask = None + if not args.apply_t5_attn_mask: + t5_attn_mask = None + + # call model + with torch.set_grad_enabled(is_train), accelerator.autocast(): + # TODO support attention mask + model_pred = unet(noisy_model_input, timesteps, context=context, y=lg_pooled) + + # Follow: Section 5 of https://arxiv.org/abs/2206.00364. + # Preconditioning of the model outputs. + model_pred = model_pred * (-sigmas) + noisy_model_input + + # these weighting schemes use a uniform timestep sampling + # and instead post-weight the loss + weighting = sd3_train_utils.compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas) + + # flow matching loss + target = latents + + # differential output preservation + if "custom_attributes" in batch: + diff_output_pr_indices = [] + for i, custom_attributes in enumerate(batch["custom_attributes"]): + if "diff_output_preservation" in custom_attributes and custom_attributes["diff_output_preservation"]: + diff_output_pr_indices.append(i) + + if len(diff_output_pr_indices) > 0: + network.set_multiplier(0.0) + with torch.no_grad(), accelerator.autocast(): + model_pred_prior = unet( + noisy_model_input[diff_output_pr_indices], + timesteps[diff_output_pr_indices], + context=context[diff_output_pr_indices], + y=lg_pooled[diff_output_pr_indices], + ) + network.set_multiplier(1.0) # may be overwritten by "network_multipliers" in the next step + + model_pred_prior = model_pred_prior * (-sigmas[diff_output_pr_indices]) + noisy_model_input[diff_output_pr_indices] + + # weighting for differential output preservation is not needed because it is already applied + + target[diff_output_pr_indices] = model_pred_prior.to(target.dtype) + + return model_pred, target, timesteps, weighting + + def post_process_loss(self, loss, args, timesteps, noise_scheduler): + return loss + + def get_sai_model_spec(self, args): + return train_util.get_sai_model_spec(None, args, False, True, False, sd3=self.model_type) + + def update_metadata(self, metadata, args): + metadata["ss_apply_lg_attn_mask"] = args.apply_lg_attn_mask + metadata["ss_apply_t5_attn_mask"] = args.apply_t5_attn_mask + metadata["ss_weighting_scheme"] = args.weighting_scheme + metadata["ss_logit_mean"] = args.logit_mean + metadata["ss_logit_std"] = args.logit_std + metadata["ss_mode_scale"] = args.mode_scale + + def is_text_encoder_not_needed_for_training(self, args): + return args.cache_text_encoder_outputs and not self.is_train_text_encoder(args) + + def prepare_text_encoder_grad_ckpt_workaround(self, index, text_encoder): + if index == 0 or index == 1: # CLIP-L/CLIP-G + return super().prepare_text_encoder_grad_ckpt_workaround(index, text_encoder) + else: # T5XXL + text_encoder.encoder.embed_tokens.requires_grad_(True) + + def prepare_text_encoder_fp8(self, index, text_encoder, te_weight_dtype, weight_dtype): + if index == 0 or index == 1: # CLIP-L/CLIP-G + clip_type = "CLIP-L" if index == 0 else "CLIP-G" + logger.info(f"prepare CLIP-{clip_type} for fp8: set to {te_weight_dtype}, set embeddings to {weight_dtype}") + text_encoder.to(te_weight_dtype) # fp8 + text_encoder.text_model.embeddings.to(dtype=weight_dtype) + else: # T5XXL + + def prepare_fp8(text_encoder, target_dtype): + def forward_hook(module): + def forward(hidden_states): + hidden_gelu = module.act(module.wi_0(hidden_states)) + hidden_linear = module.wi_1(hidden_states) + hidden_states = hidden_gelu * hidden_linear + hidden_states = module.dropout(hidden_states) + + hidden_states = module.wo(hidden_states) + return hidden_states + + return forward + + for module in text_encoder.modules(): + if module.__class__.__name__ in ["T5LayerNorm", "Embedding"]: + # print("set", module.__class__.__name__, "to", target_dtype) + module.to(target_dtype) + if module.__class__.__name__ in ["T5DenseGatedActDense"]: + # print("set", module.__class__.__name__, "hooks") + module.forward = forward_hook(module) + + if flux_utils.get_t5xxl_actual_dtype(text_encoder) == torch.float8_e4m3fn and text_encoder.dtype == weight_dtype: + logger.info(f"T5XXL already prepared for fp8") + else: + logger.info(f"prepare T5XXL for fp8: set to {te_weight_dtype}, set embeddings to {weight_dtype}, add hooks") + text_encoder.to(te_weight_dtype) # fp8 + prepare_fp8(text_encoder, weight_dtype) + + def on_step_start(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype, is_train=True): + # drop cached text encoder outputs: in validation, we drop cached outputs deterministically by fixed seed + text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None) + if text_encoder_outputs_list is not None: + text_encodoing_strategy: strategy_sd3.Sd3TextEncodingStrategy = strategy_base.TextEncodingStrategy.get_strategy() + text_encoder_outputs_list = text_encodoing_strategy.drop_cached_text_encoder_outputs(*text_encoder_outputs_list) + batch["text_encoder_outputs_list"] = text_encoder_outputs_list + + def on_validation_step_end(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype): + if self.is_swapping_blocks: + # prepare for next forward: because backward pass is not called, we need to prepare it here + accelerator.unwrap_model(unet).prepare_block_swap_before_forward() + + def prepare_unet_with_accelerator( + self, args: argparse.Namespace, accelerator: Accelerator, unet: torch.nn.Module + ) -> torch.nn.Module: + if not self.is_swapping_blocks: + return super().prepare_unet_with_accelerator(args, accelerator, unet) + + # if we doesn't swap blocks, we can move the model to device + mmdit: sd3_models.MMDiT = unet + mmdit = accelerator.prepare(mmdit, device_placement=[not self.is_swapping_blocks]) + accelerator.unwrap_model(mmdit).move_to_device_except_swap_blocks(accelerator.device) # reduce peak memory usage + accelerator.unwrap_model(mmdit).prepare_block_swap_before_forward() + + return mmdit + + +def setup_parser() -> argparse.ArgumentParser: + parser = train_network.setup_parser() + train_util.add_dit_training_arguments(parser) + sd3_train_utils.add_sd3_training_arguments(parser) + return parser + + +if __name__ == "__main__": + parser = setup_parser() + + args = parser.parse_args() + train_util.verify_command_line_training_args(args) + args = train_util.read_config_from_file(args, parser) + + trainer = Sd3NetworkTrainer() + trainer.train(args) diff --git a/sdxl_train.py b/sdxl_train.py index b533b2749..f454263a4 100644 --- a/sdxl_train.py +++ b/sdxl_train.py @@ -17,7 +17,7 @@ from accelerate.utils import set_seed from diffusers import DDPMScheduler -from library import deepspeed_utils, sdxl_model_util +from library import deepspeed_utils, sdxl_model_util, strategy_base, strategy_sd, strategy_sdxl, sai_model_spec import library.train_util as train_util @@ -104,8 +104,8 @@ def train(args): setup_logging(args, reset=True) assert ( - not args.weighted_captions - ), "weighted_captions is not supported currently / weighted_captionsは現在サポートされていません" + not args.weighted_captions or not args.cache_text_encoder_outputs + ), "weighted_captions is not supported when caching text encoder outputs / cache_text_encoder_outputsを使うときはweighted_captionsはサポートされていません" assert ( not args.train_text_encoder or not args.cache_text_encoder_outputs ), "cache_text_encoder_outputs is not supported when training text encoder / text encoderを学習するときはcache_text_encoder_outputsはサポートされていません" @@ -124,7 +124,16 @@ def train(args): if args.seed is not None: set_seed(args.seed) # 乱数系列を初期化する - tokenizer1, tokenizer2 = sdxl_train_util.load_tokenizers(args) + tokenize_strategy = strategy_sdxl.SdxlTokenizeStrategy(args.max_token_length, args.tokenizer_cache_dir) + strategy_base.TokenizeStrategy.set_strategy(tokenize_strategy) + tokenizers = [tokenize_strategy.tokenizer1, tokenize_strategy.tokenizer2] # will be removed in the future + + # prepare caching strategy: this must be set before preparing dataset. because dataset may use this strategy for initialization. + if args.cache_latents: + latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy( + False, args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check + ) + strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy) # データセットを準備する if args.dataset_class is None: @@ -166,10 +175,11 @@ def train(args): ] } - blueprint = blueprint_generator.generate(user_config, args, tokenizer=[tokenizer1, tokenizer2]) - train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) + blueprint = blueprint_generator.generate(user_config, args) + train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) else: - train_dataset_group = train_util.load_arbitrary_dataset(args, [tokenizer1, tokenizer2]) + train_dataset_group = train_util.load_arbitrary_dataset(args) + val_dataset_group = None current_epoch = Value("i", 0) current_step = Value("i", 0) @@ -262,8 +272,9 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): vae.to(accelerator.device, dtype=vae_dtype) vae.requires_grad_(False) vae.eval() - with torch.no_grad(): - train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process) + + train_dataset_group.new_cache_latents(vae, accelerator) + vae.to("cpu") clean_memory_on_device(accelerator.device) @@ -276,6 +287,9 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): train_text_encoder1 = False train_text_encoder2 = False + text_encoding_strategy = strategy_sdxl.SdxlTextEncodingStrategy() + strategy_base.TextEncodingStrategy.set_strategy(text_encoding_strategy) + if args.train_text_encoder: # TODO each option for two text encoders? accelerator.print("enable text encoder training") @@ -307,16 +321,17 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): # TextEncoderの出力をキャッシュする if args.cache_text_encoder_outputs: # Text Encodes are eval and no grad - with torch.no_grad(), accelerator.autocast(): - train_dataset_group.cache_text_encoder_outputs( - (tokenizer1, tokenizer2), - (text_encoder1, text_encoder2), - accelerator.device, - None, - args.cache_text_encoder_outputs_to_disk, - accelerator.is_main_process, - ) - accelerator.wait_for_everyone() + text_encoder_output_caching_strategy = strategy_sdxl.SdxlTextEncoderOutputsCachingStrategy( + args.cache_text_encoder_outputs_to_disk, None, False, is_weighted=args.weighted_captions + ) + strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(text_encoder_output_caching_strategy) + + text_encoder1.to(accelerator.device) + text_encoder2.to(accelerator.device) + with accelerator.autocast(): + train_dataset_group.new_cache_text_encoder_outputs([text_encoder1, text_encoder2], accelerator) + + accelerator.wait_for_everyone() if not cache_latents: vae.requires_grad_(False) @@ -403,7 +418,11 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): else: _, _, optimizer = train_util.get_optimizer(args, trainable_params=params_to_optimize) - # dataloaderを準備する + # prepare dataloader + # strategies are set here because they cannot be referenced in another process. Copy them with the dataset + # some strategies can be None + train_dataset_group.set_current_strategies() + # DataLoaderのプロセス数:0 は persistent_workers が使えないので注意 n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers train_dataloader = torch.utils.data.DataLoader( @@ -597,8 +616,11 @@ def optimizer_hook(parameter: torch.Tensor): # For --sample_at_first sdxl_train_util.sample_images( - accelerator, args, 0, global_step, accelerator.device, vae, [tokenizer1, tokenizer2], [text_encoder1, text_encoder2], unet + accelerator, args, 0, global_step, accelerator.device, vae, tokenizers, [text_encoder1, text_encoder2], unet ) + if len(accelerator.trackers) > 0: + # log empty object to commit the sample images to wandb + accelerator.log({}, step=0) loss_recorder = train_util.LossRecorder() for epoch in range(num_train_epochs): @@ -628,57 +650,39 @@ def optimizer_hook(parameter: torch.Tensor): latents = torch.nan_to_num(latents, 0, out=latents) latents = latents * sdxl_model_util.VAE_SCALE_FACTOR - if "text_encoder_outputs1_list" not in batch or batch["text_encoder_outputs1_list"] is None: - input_ids1 = batch["input_ids"] - input_ids2 = batch["input_ids2"] + text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None) + if text_encoder_outputs_list is not None: + # Text Encoder outputs are cached + encoder_hidden_states1, encoder_hidden_states2, pool2 = text_encoder_outputs_list + encoder_hidden_states1 = encoder_hidden_states1.to(accelerator.device, dtype=weight_dtype) + encoder_hidden_states2 = encoder_hidden_states2.to(accelerator.device, dtype=weight_dtype) + pool2 = pool2.to(accelerator.device, dtype=weight_dtype) + else: + input_ids1, input_ids2 = batch["input_ids_list"] with torch.set_grad_enabled(args.train_text_encoder): # Get the text embedding for conditioning - # TODO support weighted captions - # if args.weighted_captions: - # encoder_hidden_states = get_weighted_text_embeddings( - # tokenizer, - # text_encoder, - # batch["captions"], - # accelerator.device, - # args.max_token_length // 75 if args.max_token_length else 1, - # clip_skip=args.clip_skip, - # ) - # else: - input_ids1 = input_ids1.to(accelerator.device) - input_ids2 = input_ids2.to(accelerator.device) - # unwrap_model is fine for models not wrapped by accelerator - encoder_hidden_states1, encoder_hidden_states2, pool2 = train_util.get_hidden_states_sdxl( - args.max_token_length, - input_ids1, - input_ids2, - tokenizer1, - tokenizer2, - text_encoder1, - text_encoder2, - None if not args.full_fp16 else weight_dtype, - accelerator=accelerator, - ) - else: - encoder_hidden_states1 = batch["text_encoder_outputs1_list"].to(accelerator.device).to(weight_dtype) - encoder_hidden_states2 = batch["text_encoder_outputs2_list"].to(accelerator.device).to(weight_dtype) - pool2 = batch["text_encoder_pool2_list"].to(accelerator.device).to(weight_dtype) - - # # verify that the text encoder outputs are correct - # ehs1, ehs2, p2 = train_util.get_hidden_states_sdxl( - # args.max_token_length, - # batch["input_ids"].to(text_encoder1.device), - # batch["input_ids2"].to(text_encoder1.device), - # tokenizer1, - # tokenizer2, - # text_encoder1, - # text_encoder2, - # None if not args.full_fp16 else weight_dtype, - # ) - # b_size = encoder_hidden_states1.shape[0] - # assert ((encoder_hidden_states1.to("cpu") - ehs1.to(dtype=weight_dtype)).abs().max() > 1e-2).sum() <= b_size * 2 - # assert ((encoder_hidden_states2.to("cpu") - ehs2.to(dtype=weight_dtype)).abs().max() > 1e-2).sum() <= b_size * 2 - # assert ((pool2.to("cpu") - p2.to(dtype=weight_dtype)).abs().max() > 1e-2).sum() <= b_size * 2 - # logger.info("text encoder outputs verified") + if args.weighted_captions: + input_ids_list, weights_list = tokenize_strategy.tokenize_with_weights(batch["captions"]) + encoder_hidden_states1, encoder_hidden_states2, pool2 = ( + text_encoding_strategy.encode_tokens_with_weights( + tokenize_strategy, + [text_encoder1, text_encoder2, accelerator.unwrap_model(text_encoder2)], + input_ids_list, + weights_list, + ) + ) + else: + input_ids1 = input_ids1.to(accelerator.device) + input_ids2 = input_ids2.to(accelerator.device) + encoder_hidden_states1, encoder_hidden_states2, pool2 = text_encoding_strategy.encode_tokens( + tokenize_strategy, + [text_encoder1, text_encoder2, accelerator.unwrap_model(text_encoder2)], + [input_ids1, input_ids2], + ) + if args.full_fp16: + encoder_hidden_states1 = encoder_hidden_states1.to(weight_dtype) + encoder_hidden_states2 = encoder_hidden_states2.to(weight_dtype) + pool2 = pool2.to(weight_dtype) # get size embeddings orig_size = batch["original_sizes_hw"] @@ -692,9 +696,7 @@ def optimizer_hook(parameter: torch.Tensor): # Sample noise, sample a random timestep for each image, and add noise to the latents, # with noise offset and/or multires noise if specified - noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps( - args, noise_scheduler, latents - ) + noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents) noisy_latents = noisy_latents.to(weight_dtype) # TODO check why noisy_latents is not weight_dtype @@ -708,6 +710,7 @@ def optimizer_hook(parameter: torch.Tensor): else: target = noise + huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler) if ( args.min_snr_gamma or args.scale_v_pred_loss_like_noise_pred @@ -716,9 +719,7 @@ def optimizer_hook(parameter: torch.Tensor): or args.masked_loss ): # do not mean over batch dimension for snr weight or scale v-pred loss - loss = train_util.conditional_loss( - noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c - ) + loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c) if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None): loss = apply_masked_loss(loss, batch) loss = loss.mean([1, 2, 3]) @@ -734,9 +735,7 @@ def optimizer_hook(parameter: torch.Tensor): loss = loss.mean() # mean over batch dimension else: - loss = train_util.conditional_loss( - noise_pred.float(), target.float(), reduction="mean", loss_type=args.loss_type, huber_c=huber_c - ) + loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "mean", huber_c) accelerator.backward(loss) @@ -769,7 +768,7 @@ def optimizer_hook(parameter: torch.Tensor): global_step, accelerator.device, vae, - [tokenizer1, tokenizer2], + tokenizers, [text_encoder1, text_encoder2], unet, ) @@ -799,7 +798,7 @@ def optimizer_hook(parameter: torch.Tensor): ) current_loss = loss.detach().item() # 平均なのでbatch sizeは関係ないはず - if args.logging_dir is not None: + if len(accelerator.trackers) > 0: logs = {"loss": current_loss} if block_lrs is None: train_util.append_lr_to_logs(logs, lr_scheduler, args.optimizer_type, including_unet=train_unet) @@ -816,7 +815,7 @@ def optimizer_hook(parameter: torch.Tensor): if global_step >= args.max_train_steps: break - if args.logging_dir is not None: + if len(accelerator.trackers) > 0: logs = {"loss/epoch": loss_recorder.moving_average} accelerator.log(logs, step=epoch + 1) @@ -851,7 +850,7 @@ def optimizer_hook(parameter: torch.Tensor): global_step, accelerator.device, vae, - [tokenizer1, tokenizer2], + tokenizers, [text_encoder1, text_encoder2], unet, ) @@ -894,6 +893,7 @@ def setup_parser() -> argparse.ArgumentParser: add_logging_arguments(parser) train_util.add_sd_models_arguments(parser) + sai_model_spec.add_model_spec_arguments(parser) train_util.add_dataset_arguments(parser, True, True, True) train_util.add_training_arguments(parser, False) train_util.add_masked_loss_arguments(parser) diff --git a/sdxl_train_control_net.py b/sdxl_train_control_net.py new file mode 100644 index 000000000..3d107e57c --- /dev/null +++ b/sdxl_train_control_net.py @@ -0,0 +1,721 @@ +import argparse +import math +import os +import random +from multiprocessing import Value +import toml + +from tqdm import tqdm + +import torch +from library.device_utils import init_ipex, clean_memory_on_device + +init_ipex() + +from accelerate.utils import set_seed +from accelerate import init_empty_weights +from diffusers import DDPMScheduler +from diffusers.utils.torch_utils import is_compiled_module +from safetensors.torch import load_file +from library import ( + deepspeed_utils, + sai_model_spec, + sdxl_model_util, + sdxl_train_util, + strategy_base, + strategy_sd, + strategy_sdxl, + sai_model_spec +) + +import library.train_util as train_util +import library.config_util as config_util +from library.config_util import ( + ConfigSanitizer, + BlueprintGenerator, +) +import library.huggingface_util as huggingface_util +import library.custom_train_functions as custom_train_functions +from library.custom_train_functions import ( + add_v_prediction_like_loss, + apply_snr_weight, + prepare_scheduler_for_custom_training, + scale_v_prediction_loss_like_noise_prediction, + apply_debiased_estimation, +) +from library.sdxl_original_control_net import SdxlControlNet, SdxlControlledUNet +from library.utils import setup_logging, add_logging_arguments + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + + +# TODO 他のスクリプトと共通化する +def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_scheduler): + logs = { + "loss/current": current_loss, + "loss/average": avr_loss, + "lr": lr_scheduler.get_last_lr()[0], + } + + if args.optimizer_type.lower().startswith("DAdapt".lower()): + logs["lr/d*lr"] = lr_scheduler.optimizers[-1].param_groups[0]["d"] * lr_scheduler.optimizers[-1].param_groups[0]["lr"] + + return logs + + +def train(args): + train_util.verify_training_args(args) + train_util.prepare_dataset_args(args, True) + sdxl_train_util.verify_sdxl_training_args(args) + setup_logging(args, reset=True) + + cache_latents = args.cache_latents + use_user_config = args.dataset_config is not None + + if args.seed is None: + args.seed = random.randint(0, 2**32) + set_seed(args.seed) + + tokenize_strategy = strategy_sdxl.SdxlTokenizeStrategy(args.max_token_length, args.tokenizer_cache_dir) + strategy_base.TokenizeStrategy.set_strategy(tokenize_strategy) + tokenizer1, tokenizer2 = tokenize_strategy.tokenizer1, tokenize_strategy.tokenizer2 # this is used for sampling images + + # prepare caching strategy: this must be set before preparing dataset. because dataset may use this strategy for initialization. + latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy( + False, args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check + ) + strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy) + + # データセットを準備する + blueprint_generator = BlueprintGenerator(ConfigSanitizer(False, False, True, True)) + if use_user_config: + logger.info(f"Load dataset config from {args.dataset_config}") + user_config = config_util.load_user_config(args.dataset_config) + ignored = ["train_data_dir", "conditioning_data_dir"] + if any(getattr(args, attr) is not None for attr in ignored): + logger.warning( + "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( + ", ".join(ignored) + ) + ) + else: + user_config = { + "datasets": [ + { + "subsets": config_util.generate_controlnet_subsets_config_by_subdirs( + args.train_data_dir, + args.conditioning_data_dir, + args.caption_extension, + ) + } + ] + } + + blueprint = blueprint_generator.generate(user_config, args) + train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) + + current_epoch = Value("i", 0) + current_step = Value("i", 0) + ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None + collator = train_util.collator_class(current_epoch, current_step, ds_for_collator) + + train_dataset_group.verify_bucket_reso_steps(32) + + if args.debug_dataset: + train_dataset_group.set_current_strategies() # dasaset needs to know the strategies explicitly + train_util.debug_dataset(train_dataset_group) + return + if len(train_dataset_group) == 0: + logger.error( + "No data found. Please verify arguments (train_data_dir must be the parent of folders with images) / 画像がありません。引数指定を確認してください(train_data_dirには画像があるフォルダではなく、画像があるフォルダの親フォルダを指定する必要があります)" + ) + return + + if cache_latents: + assert ( + train_dataset_group.is_latent_cacheable() + ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません" + else: + logger.warning( + "WARNING: random_crop is not supported yet for ControlNet training / ControlNetの学習ではrandom_cropはまだサポートされていません" + ) + + if args.cache_text_encoder_outputs: + assert ( + train_dataset_group.is_text_encoder_output_cacheable() + ), "when caching Text Encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / Text Encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません" + + # acceleratorを準備する + logger.info("prepare accelerator") + accelerator = train_util.prepare_accelerator(args) + is_main_process = accelerator.is_main_process + + def unwrap_model(model): + model = accelerator.unwrap_model(model) + model = model._orig_mod if is_compiled_module(model) else model + return model + + # mixed precisionに対応した型を用意しておき適宜castする + weight_dtype, save_dtype = train_util.prepare_dtype(args) + vae_dtype = torch.float32 if args.no_half_vae else weight_dtype + + # モデルを読み込む + ( + load_stable_diffusion_format, + text_encoder1, + text_encoder2, + vae, + unet, + logit_scale, + ckpt_info, + ) = sdxl_train_util.load_target_model(args, accelerator, sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, weight_dtype) + + unet.to(accelerator.device) # reduce main memory usage + + # convert U-Net to Controlled U-Net + logger.info("convert U-Net to Controlled U-Net") + unet_sd = unet.state_dict() + with init_empty_weights(): + unet = SdxlControlledUNet() + unet.load_state_dict(unet_sd, strict=True, assign=True) + del unet_sd + + # make control net + logger.info("make ControlNet") + if args.controlnet_model_name_or_path: + with init_empty_weights(): + control_net = SdxlControlNet() + + logger.info(f"load ControlNet from {args.controlnet_model_name_or_path}") + filename = args.controlnet_model_name_or_path + if os.path.splitext(filename)[1] == ".safetensors": + state_dict = load_file(filename) + else: + state_dict = torch.load(filename) + info = control_net.load_state_dict(state_dict, strict=True, assign=True) + logger.info(f"ControlNet loaded from {filename}: {info}") + else: + control_net = SdxlControlNet() + + logger.info("initialize ControlNet from U-Net") + info = control_net.init_from_unet(unet) + logger.info(f"ControlNet initialized from U-Net: {info}") + + # 学習を準備する + if cache_latents: + vae.to(accelerator.device, dtype=vae_dtype) + vae.requires_grad_(False) + vae.eval() + + train_dataset_group.new_cache_latents(vae, accelerator) + + vae.to("cpu") + clean_memory_on_device(accelerator.device) + + accelerator.wait_for_everyone() + + text_encoding_strategy = strategy_sdxl.SdxlTextEncodingStrategy() + strategy_base.TextEncodingStrategy.set_strategy(text_encoding_strategy) + + # TextEncoderの出力をキャッシュする + if args.cache_text_encoder_outputs: + # Text Encodes are eval and no grad + text_encoder_output_caching_strategy = strategy_sdxl.SdxlTextEncoderOutputsCachingStrategy( + args.cache_text_encoder_outputs_to_disk, None, False + ) + strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(text_encoder_output_caching_strategy) + + text_encoder1.to(accelerator.device) + text_encoder2.to(accelerator.device) + with accelerator.autocast(): + train_dataset_group.new_cache_text_encoder_outputs([text_encoder1, text_encoder2], accelerator) + + accelerator.wait_for_everyone() + + # モデルに xformers とか memory efficient attention を組み込む + # train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa) + if args.xformers: + unet.set_use_memory_efficient_attention(True, False) + control_net.set_use_memory_efficient_attention(True, False) + elif args.sdpa: + unet.set_use_sdpa(True) + control_net.set_use_sdpa(True) + + if args.gradient_checkpointing: + unet.enable_gradient_checkpointing() + control_net.enable_gradient_checkpointing() + + # 学習に必要なクラスを準備する + accelerator.print("prepare optimizer, data loader etc.") + + trainable_params = [] + ctrlnet_params = [] + unet_params = [] + for name, param in control_net.named_parameters(): + if name.startswith("controlnet_"): + ctrlnet_params.append(param) + else: + unet_params.append(param) + trainable_params.append({"params": ctrlnet_params, "lr": args.control_net_lr}) + trainable_params.append({"params": unet_params, "lr": args.learning_rate}) + all_params = ctrlnet_params + unet_params + + logger.info(f"trainable params count: {len(all_params)}") + logger.info(f"number of trainable parameters: {sum(p.numel() for p in all_params)}") + + _, _, optimizer = train_util.get_optimizer(args, trainable_params) + + # prepare dataloader + # strategies are set here because they cannot be referenced in another process. Copy them with the dataset + # some strategies can be None + train_dataset_group.set_current_strategies() + + # DataLoaderのプロセス数:0 は persistent_workers が使えないので注意 + n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers + + train_dataloader = torch.utils.data.DataLoader( + train_dataset_group, + batch_size=1, + shuffle=True, + collate_fn=collator, + num_workers=n_workers, + persistent_workers=args.persistent_data_loader_workers, + ) + + # 学習ステップ数を計算する + if args.max_train_epochs is not None: + args.max_train_steps = args.max_train_epochs * math.ceil( + len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps + ) + accelerator.print( + f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}" + ) + + # データセット側にも学習ステップを送信 + train_dataset_group.set_max_train_steps(args.max_train_steps) + + # lr schedulerを用意する + lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) + + # 実験的機能:勾配も含めたfp16/bf16学習を行う モデル全体をfp16/bf16にする + if args.full_fp16: + assert ( + args.mixed_precision == "fp16" + ), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。" + accelerator.print("enable full fp16 training.") + control_net.to(weight_dtype) + elif args.full_bf16: + assert ( + args.mixed_precision == "bf16" + ), "full_bf16 requires mixed precision='bf16' / full_bf16を使う場合はmixed_precision='bf16'を指定してください。" + accelerator.print("enable full bf16 training.") + control_net.to(weight_dtype) + + # acceleratorがなんかよろしくやってくれるらしい + control_net, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + control_net, optimizer, train_dataloader, lr_scheduler + ) + + if args.fused_backward_pass: + # use fused optimizer for backward pass: other optimizers will be supported in the future + import library.adafactor_fused + + library.adafactor_fused.patch_adafactor_fused(optimizer) + for param_group in optimizer.param_groups: + for parameter in param_group["params"]: + if parameter.requires_grad: + + def __grad_hook(tensor: torch.Tensor, param_group=param_group): + if accelerator.sync_gradients and args.max_grad_norm != 0.0: + accelerator.clip_grad_norm_(tensor, args.max_grad_norm) + optimizer.step_param(tensor, param_group) + tensor.grad = None + + parameter.register_post_accumulate_grad_hook(__grad_hook) + + unet.requires_grad_(False) + text_encoder1.requires_grad_(False) + text_encoder2.requires_grad_(False) + unet.to(accelerator.device, dtype=weight_dtype) + + unet.eval() + control_net.train() + + # TextEncoderの出力をキャッシュするときにはCPUへ移動する + if args.cache_text_encoder_outputs: + # move Text Encoders for sampling images. Text Encoder doesn't work on CPU with fp16 + text_encoder1.to("cpu", dtype=torch.float32) + text_encoder2.to("cpu", dtype=torch.float32) + clean_memory_on_device(accelerator.device) + else: + # make sure Text Encoders are on GPU + text_encoder1.to(accelerator.device) + text_encoder2.to(accelerator.device) + + if not cache_latents: + vae.requires_grad_(False) + vae.eval() + vae.to(accelerator.device, dtype=vae_dtype) + + # 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする + if args.full_fp16: + train_util.patch_accelerator_for_fp16_training(accelerator) + + # resumeする + train_util.resume_from_local_or_hf_if_specified(accelerator, args) + + # epoch数を計算する + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0): + args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1 + + # 学習する + # TODO: find a way to handle total batch size when there are multiple datasets + accelerator.print("running training / 学習開始") + accelerator.print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}") + accelerator.print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}") + accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}") + accelerator.print(f" num epochs / epoch数: {num_train_epochs}") + accelerator.print( + f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}" + ) + # logger.info(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}") + accelerator.print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}") + accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}") + + progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps") + global_step = 0 + + noise_scheduler = DDPMScheduler( + beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False + ) + prepare_scheduler_for_custom_training(noise_scheduler, accelerator.device) + if args.zero_terminal_snr: + custom_train_functions.fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler) + + if accelerator.is_main_process: + init_kwargs = {} + if args.wandb_run_name: + init_kwargs["wandb"] = {"name": args.wandb_run_name} + if args.log_tracker_config is not None: + init_kwargs = toml.load(args.log_tracker_config) + accelerator.init_trackers( + ("sdxl_control_net_train" if args.log_tracker_name is None else args.log_tracker_name), + config=train_util.get_sanitized_config_or_none(args), + init_kwargs=init_kwargs, + ) + + loss_recorder = train_util.LossRecorder() + del train_dataset_group + + # function for saving/removing + def save_model(ckpt_name, model, force_sync_upload=False): + os.makedirs(args.output_dir, exist_ok=True) + ckpt_file = os.path.join(args.output_dir, ckpt_name) + + accelerator.print(f"\nsaving checkpoint: {ckpt_file}") + sai_metadata = train_util.get_sai_model_spec(None, args, True, True, False) + sai_metadata["modelspec.architecture"] = sai_model_spec.ARCH_SD_XL_V1_BASE + "/controlnet" + state_dict = model.state_dict() + + if save_dtype is not None: + for key in list(state_dict.keys()): + v = state_dict[key] + v = v.detach().clone().to("cpu").to(save_dtype) + state_dict[key] = v + + if os.path.splitext(ckpt_file)[1] == ".safetensors": + from safetensors.torch import save_file + + save_file(state_dict, ckpt_file, sai_metadata) + else: + torch.save(state_dict, ckpt_file) + + if args.huggingface_repo_id is not None: + huggingface_util.upload(args, ckpt_file, "/" + ckpt_name, force_sync_upload=force_sync_upload) + + def remove_model(old_ckpt_name): + old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name) + if os.path.exists(old_ckpt_file): + accelerator.print(f"removing old checkpoint: {old_ckpt_file}") + os.remove(old_ckpt_file) + + # For --sample_at_first + sdxl_train_util.sample_images( + accelerator, + args, + 0, + global_step, + accelerator.device, + vae, + [tokenizer1, tokenizer2], + [text_encoder1, text_encoder2, unwrap_model(text_encoder2)], + unet, + controlnet=control_net, + ) + + # training loop + for epoch in range(num_train_epochs): + accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}") + current_epoch.value = epoch + 1 + + control_net.train() + + for step, batch in enumerate(train_dataloader): + current_step.value = global_step + with accelerator.accumulate(control_net): + with torch.no_grad(): + if "latents" in batch and batch["latents"] is not None: + latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype) + else: + # latentに変換 + latents = vae.encode(batch["images"].to(dtype=vae_dtype)).latent_dist.sample().to(dtype=weight_dtype) + + # NaNが含まれていれば警告を表示し0に置き換える + if torch.any(torch.isnan(latents)): + accelerator.print("NaN found in latents, replacing with zeros") + latents = torch.nan_to_num(latents, 0, out=latents) + latents = latents * sdxl_model_util.VAE_SCALE_FACTOR + + text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None) + if text_encoder_outputs_list is not None: + # Text Encoder outputs are cached + encoder_hidden_states1, encoder_hidden_states2, pool2 = text_encoder_outputs_list + encoder_hidden_states1 = encoder_hidden_states1.to(accelerator.device, dtype=weight_dtype) + encoder_hidden_states2 = encoder_hidden_states2.to(accelerator.device, dtype=weight_dtype) + pool2 = pool2.to(accelerator.device, dtype=weight_dtype) + else: + input_ids1, input_ids2 = batch["input_ids_list"] + with torch.no_grad(): + input_ids1 = input_ids1.to(accelerator.device) + input_ids2 = input_ids2.to(accelerator.device) + encoder_hidden_states1, encoder_hidden_states2, pool2 = text_encoding_strategy.encode_tokens( + tokenize_strategy, [text_encoder1, text_encoder2, unwrap_model(text_encoder2)], [input_ids1, input_ids2] + ) + if args.full_fp16: + encoder_hidden_states1 = encoder_hidden_states1.to(weight_dtype) + encoder_hidden_states2 = encoder_hidden_states2.to(weight_dtype) + pool2 = pool2.to(weight_dtype) + + # get size embeddings + orig_size = batch["original_sizes_hw"] + crop_size = batch["crop_top_lefts"] + target_size = batch["target_sizes_hw"] + embs = sdxl_train_util.get_size_embeddings(orig_size, crop_size, target_size, accelerator.device).to(weight_dtype) + + # concat embeddings + vector_embedding = torch.cat([pool2, embs], dim=1).to(weight_dtype) + text_embedding = torch.cat([encoder_hidden_states1, encoder_hidden_states2], dim=2).to(weight_dtype) + + # Sample noise, sample a random timestep for each image, and add noise to the latents, + # with noise offset and/or multires noise if specified + noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents) + + controlnet_image = batch["conditioning_images"].to(dtype=weight_dtype) + + # '-1 to +1' to '0 to 1' + controlnet_image = (controlnet_image + 1) / 2 + + with accelerator.autocast(): + input_resi_add, mid_add = control_net( + noisy_latents, timesteps, text_embedding, vector_embedding, controlnet_image + ) + noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_embedding, input_resi_add, mid_add) + + if args.v_parameterization: + # v-parameterization training + target = noise_scheduler.get_velocity(latents, noise, timesteps) + else: + target = noise + + huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler) + loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c) + loss = loss.mean([1, 2, 3]) + + loss_weights = batch["loss_weights"] # 各sampleごとのweight + loss = loss * loss_weights + + if args.min_snr_gamma: + loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization) + if args.scale_v_pred_loss_like_noise_pred: + loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler) + if args.v_pred_like_loss: + loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss) + if args.debiased_estimation_loss: + loss = apply_debiased_estimation(loss, timesteps, noise_scheduler) + + loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし + + accelerator.backward(loss) + if not args.fused_backward_pass: + if accelerator.sync_gradients and args.max_grad_norm != 0.0: + params_to_clip = control_net.parameters() + accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) + + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad(set_to_none=True) + else: + # optimizer.step() and optimizer.zero_grad() are called in the optimizer hook + lr_scheduler.step() + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + global_step += 1 + + sdxl_train_util.sample_images( + accelerator, + args, + None, + global_step, + accelerator.device, + vae, + [tokenizer1, tokenizer2], + [text_encoder1, text_encoder2, unwrap_model(text_encoder2)], + unet, + controlnet=control_net, + ) + + # 指定ステップごとにモデルを保存 + if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0: + accelerator.wait_for_everyone() + if accelerator.is_main_process: + ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, global_step) + save_model(ckpt_name, unwrap_model(control_net)) + + if args.save_state: + train_util.save_and_remove_state_stepwise(args, accelerator, global_step) + + remove_step_no = train_util.get_remove_step_no(args, global_step) + if remove_step_no is not None: + remove_ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, remove_step_no) + remove_model(remove_ckpt_name) + + current_loss = loss.detach().item() + loss_recorder.add(epoch=epoch, step=step, loss=current_loss) + avr_loss: float = loss_recorder.moving_average + logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]} + progress_bar.set_postfix(**logs) + + if len(accelerator.trackers) > 0: + logs = generate_step_logs(args, current_loss, avr_loss, lr_scheduler) + accelerator.log(logs, step=global_step) + + if global_step >= args.max_train_steps: + break + + if len(accelerator.trackers) > 0: + logs = {"loss/epoch": loss_recorder.moving_average} + accelerator.log(logs, step=epoch + 1) + + accelerator.wait_for_everyone() + + # 指定エポックごとにモデルを保存 + if args.save_every_n_epochs is not None: + saving = (epoch + 1) % args.save_every_n_epochs == 0 and (epoch + 1) < num_train_epochs + if is_main_process and saving: + ckpt_name = train_util.get_epoch_ckpt_name(args, "." + args.save_model_as, epoch + 1) + save_model(ckpt_name, unwrap_model(control_net)) + + remove_epoch_no = train_util.get_remove_epoch_no(args, epoch + 1) + if remove_epoch_no is not None: + remove_ckpt_name = train_util.get_epoch_ckpt_name(args, "." + args.save_model_as, remove_epoch_no) + remove_model(remove_ckpt_name) + + if args.save_state: + train_util.save_and_remove_state_on_epoch_end(args, accelerator, epoch + 1) + + sdxl_train_util.sample_images( + accelerator, + args, + epoch + 1, + global_step, + accelerator.device, + vae, + [tokenizer1, tokenizer2], + [text_encoder1, text_encoder2, unwrap_model(text_encoder2)], + unet, + controlnet=control_net, + ) + + # end of epoch + + if is_main_process: + control_net = unwrap_model(control_net) + + accelerator.end_training() + + if is_main_process and (args.save_state or args.save_state_on_train_end): + train_util.save_state_on_train_end(args, accelerator) + + if is_main_process: + ckpt_name = train_util.get_last_ckpt_name(args, "." + args.save_model_as) + save_model(ckpt_name, control_net, force_sync_upload=True) + + logger.info("model saved.") + + +def setup_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser() + + add_logging_arguments(parser) + train_util.add_sd_models_arguments(parser) + sai_model_spec.add_model_spec_arguments(parser) + train_util.add_dataset_arguments(parser, False, True, True) + train_util.add_training_arguments(parser, False) + # train_util.add_masked_loss_arguments(parser) + deepspeed_utils.add_deepspeed_arguments(parser) + # train_util.add_sd_saving_arguments(parser) + train_util.add_optimizer_arguments(parser) + config_util.add_config_arguments(parser) + custom_train_functions.add_custom_train_arguments(parser) + sdxl_train_util.add_sdxl_training_arguments(parser) + + parser.add_argument( + "--controlnet_model_name_or_path", + type=str, + default=None, + help="controlnet model name or path / controlnetのモデル名またはパス", + ) + parser.add_argument( + "--conditioning_data_dir", + type=str, + default=None, + help="conditioning data directory / 条件付けデータのディレクトリ", + ) + parser.add_argument( + "--save_model_as", + type=str, + default="safetensors", + choices=[None, "ckpt", "pt", "safetensors"], + help="format to save the model (default is .safetensors) / モデル保存時の形式(デフォルトはsafetensors)", + ) + parser.add_argument( + "--no_half_vae", + action="store_true", + help="do not use fp16/bf16 VAE in mixed precision (use float VAE) / mixed precisionでも fp16/bf16 VAEを使わずfloat VAEを使う", + ) + parser.add_argument( + "--control_net_lr", + type=float, + default=1e-4, + help="learning rate for controlnet modules / controlnetモジュールの学習率", + ) + return parser + + +if __name__ == "__main__": + # sdxl_original_unet.USE_REENTRANT = False + + parser = setup_parser() + + args = parser.parse_args() + train_util.verify_command_line_training_args(args) + args = train_util.read_config_from_file(args, parser) + + train(args) diff --git a/sdxl_train_control_net_lllite.py b/sdxl_train_control_net_lllite.py index 0e67cde5c..4dd4b8d94 100644 --- a/sdxl_train_control_net_lllite.py +++ b/sdxl_train_control_net_lllite.py @@ -23,7 +23,17 @@ import accelerate from diffusers import DDPMScheduler, ControlNetModel from safetensors.torch import load_file -from library import deepspeed_utils, sai_model_spec, sdxl_model_util, sdxl_original_unet, sdxl_train_util +from library import ( + deepspeed_utils, + sai_model_spec, + sdxl_model_util, + sdxl_original_unet, + sdxl_train_util, + strategy_base, + strategy_sd, + strategy_sdxl, + sai_model_spec, +) import library.model_util as model_util import library.train_util as train_util @@ -79,7 +89,14 @@ def train(args): args.seed = random.randint(0, 2**32) set_seed(args.seed) - tokenizer1, tokenizer2 = sdxl_train_util.load_tokenizers(args) + tokenize_strategy = strategy_sdxl.SdxlTokenizeStrategy(args.max_token_length, args.tokenizer_cache_dir) + strategy_base.TokenizeStrategy.set_strategy(tokenize_strategy) + + # prepare caching strategy: this must be set before preparing dataset. because dataset may use this strategy for initialization. + latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy( + False, args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check + ) + strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy) # データセットを準備する blueprint_generator = BlueprintGenerator(ConfigSanitizer(False, False, True, True)) @@ -106,8 +123,8 @@ def train(args): ] } - blueprint = blueprint_generator.generate(user_config, args, tokenizer=[tokenizer1, tokenizer2]) - train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) + blueprint = blueprint_generator.generate(user_config, args) + train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) current_epoch = Value("i", 0) current_step = Value("i", 0) @@ -164,30 +181,30 @@ def train(args): vae.to(accelerator.device, dtype=vae_dtype) vae.requires_grad_(False) vae.eval() - with torch.no_grad(): - train_dataset_group.cache_latents( - vae, - args.vae_batch_size, - args.cache_latents_to_disk, - accelerator.is_main_process, - ) + + train_dataset_group.new_cache_latents(vae, accelerator) + vae.to("cpu") clean_memory_on_device(accelerator.device) accelerator.wait_for_everyone() + text_encoding_strategy = strategy_sdxl.SdxlTextEncodingStrategy() + strategy_base.TextEncodingStrategy.set_strategy(text_encoding_strategy) + # TextEncoderの出力をキャッシュする if args.cache_text_encoder_outputs: # Text Encodes are eval and no grad - with torch.no_grad(): - train_dataset_group.cache_text_encoder_outputs( - (tokenizer1, tokenizer2), - (text_encoder1, text_encoder2), - accelerator.device, - None, - args.cache_text_encoder_outputs_to_disk, - accelerator.is_main_process, - ) + text_encoder_output_caching_strategy = strategy_sdxl.SdxlTextEncoderOutputsCachingStrategy( + args.cache_text_encoder_outputs_to_disk, None, False + ) + strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(text_encoder_output_caching_strategy) + + text_encoder1.to(accelerator.device) + text_encoder2.to(accelerator.device) + with accelerator.autocast(): + train_dataset_group.new_cache_text_encoder_outputs([text_encoder1, text_encoder2], accelerator) + accelerator.wait_for_everyone() # prepare ControlNet-LLLite @@ -242,7 +259,11 @@ def train(args): _, _, optimizer = train_util.get_optimizer(args, trainable_params) - # dataloaderを準備する + # prepare dataloader + # strategies are set here because they cannot be referenced in another process. Copy them with the dataset + # some strategies can be None + train_dataset_group.set_current_strategies() + # DataLoaderのプロセス数:0 は persistent_workers が使えないので注意 n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers @@ -290,7 +311,7 @@ def train(args): unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler) if isinstance(unet, DDP): - unet._set_static_graph() # avoid error for multiple use of the parameter + unet._set_static_graph() # avoid error for multiple use of the parameter if args.gradient_checkpointing: unet.train() # according to TI example in Diffusers, train is required -> これオリジナルのU-Netしたので本当は外せる @@ -357,7 +378,9 @@ def train(args): if args.log_tracker_config is not None: init_kwargs = toml.load(args.log_tracker_config) accelerator.init_trackers( - "lllite_control_net_train" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.get_sanitized_config_or_none(args), init_kwargs=init_kwargs + "lllite_control_net_train" if args.log_tracker_name is None else args.log_tracker_name, + config=train_util.get_sanitized_config_or_none(args), + init_kwargs=init_kwargs, ) loss_recorder = train_util.LossRecorder() @@ -409,27 +432,25 @@ def remove_model(old_ckpt_name): latents = torch.nan_to_num(latents, 0, out=latents) latents = latents * sdxl_model_util.VAE_SCALE_FACTOR - if "text_encoder_outputs1_list" not in batch or batch["text_encoder_outputs1_list"] is None: - input_ids1 = batch["input_ids"] - input_ids2 = batch["input_ids2"] + text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None) + if text_encoder_outputs_list is not None: + # Text Encoder outputs are cached + encoder_hidden_states1, encoder_hidden_states2, pool2 = text_encoder_outputs_list + encoder_hidden_states1 = encoder_hidden_states1.to(accelerator.device, dtype=weight_dtype) + encoder_hidden_states2 = encoder_hidden_states2.to(accelerator.device, dtype=weight_dtype) + pool2 = pool2.to(accelerator.device, dtype=weight_dtype) + else: + input_ids1, input_ids2 = batch["input_ids_list"] with torch.no_grad(): - # Get the text embedding for conditioning input_ids1 = input_ids1.to(accelerator.device) input_ids2 = input_ids2.to(accelerator.device) - encoder_hidden_states1, encoder_hidden_states2, pool2 = train_util.get_hidden_states_sdxl( - args.max_token_length, - input_ids1, - input_ids2, - tokenizer1, - tokenizer2, - text_encoder1, - text_encoder2, - None if not args.full_fp16 else weight_dtype, + encoder_hidden_states1, encoder_hidden_states2, pool2 = text_encoding_strategy.encode_tokens( + tokenize_strategy, [text_encoder1, text_encoder2], [input_ids1, input_ids2] ) - else: - encoder_hidden_states1 = batch["text_encoder_outputs1_list"].to(accelerator.device).to(weight_dtype) - encoder_hidden_states2 = batch["text_encoder_outputs2_list"].to(accelerator.device).to(weight_dtype) - pool2 = batch["text_encoder_pool2_list"].to(accelerator.device).to(weight_dtype) + if args.full_fp16: + encoder_hidden_states1 = encoder_hidden_states1.to(weight_dtype) + encoder_hidden_states2 = encoder_hidden_states2.to(weight_dtype) + pool2 = pool2.to(weight_dtype) # get size embeddings orig_size = batch["original_sizes_hw"] @@ -443,9 +464,7 @@ def remove_model(old_ckpt_name): # Sample noise, sample a random timestep for each image, and add noise to the latents, # with noise offset and/or multires noise if specified - noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps( - args, noise_scheduler, latents - ) + noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents) noisy_latents = noisy_latents.to(weight_dtype) # TODO check why noisy_latents is not weight_dtype @@ -464,9 +483,8 @@ def remove_model(old_ckpt_name): else: target = noise - loss = train_util.conditional_loss( - noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c - ) + huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler) + loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c) loss = loss.mean([1, 2, 3]) loss_weights = batch["loss_weights"] # 各sampleごとのweight @@ -520,14 +538,14 @@ def remove_model(old_ckpt_name): logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]} progress_bar.set_postfix(**logs) - if args.logging_dir is not None: + if len(accelerator.trackers) > 0: logs = generate_step_logs(args, current_loss, avr_loss, lr_scheduler) accelerator.log(logs, step=global_step) if global_step >= args.max_train_steps: break - if args.logging_dir is not None: + if len(accelerator.trackers) > 0: logs = {"loss/epoch": loss_recorder.moving_average} accelerator.log(logs, step=epoch + 1) @@ -572,6 +590,7 @@ def setup_parser() -> argparse.ArgumentParser: add_logging_arguments(parser) train_util.add_sd_models_arguments(parser) + sai_model_spec.add_model_spec_arguments(parser) train_util.add_dataset_arguments(parser, False, True, True) train_util.add_training_arguments(parser, False) deepspeed_utils.add_deepspeed_arguments(parser) diff --git a/sdxl_train_control_net_lllite_old.py b/sdxl_train_control_net_lllite_old.py index 4a01f9e2c..0a9f4a92f 100644 --- a/sdxl_train_control_net_lllite_old.py +++ b/sdxl_train_control_net_lllite_old.py @@ -12,6 +12,7 @@ import torch from library.device_utils import init_ipex, clean_memory_on_device + init_ipex() from torch.nn.parallel import DistributedDataParallel as DDP @@ -23,6 +24,7 @@ import library.model_util as model_util import library.train_util as train_util import library.config_util as config_util +import library.sai_model_spec as sai_model_spec from library.config_util import ( ConfigSanitizer, BlueprintGenerator, @@ -102,7 +104,7 @@ def train(args): } blueprint = blueprint_generator.generate(user_config, args, tokenizer=[tokenizer1, tokenizer2]) - train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) + train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) current_epoch = Value("i", 0) current_step = Value("i", 0) @@ -324,7 +326,9 @@ def train(args): if args.log_tracker_config is not None: init_kwargs = toml.load(args.log_tracker_config) accelerator.init_trackers( - "lllite_control_net_train" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.get_sanitized_config_or_none(args), init_kwargs=init_kwargs + "lllite_control_net_train" if args.log_tracker_name is None else args.log_tracker_name, + config=train_util.get_sanitized_config_or_none(args), + init_kwargs=init_kwargs, ) loss_recorder = train_util.LossRecorder() @@ -406,7 +410,7 @@ def remove_model(old_ckpt_name): # Sample noise, sample a random timestep for each image, and add noise to the latents, # with noise offset and/or multires noise if specified - noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents) + noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents) noisy_latents = noisy_latents.to(weight_dtype) # TODO check why noisy_latents is not weight_dtype @@ -426,7 +430,8 @@ def remove_model(old_ckpt_name): else: target = noise - loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c) + huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler) + loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c) loss = loss.mean([1, 2, 3]) loss_weights = batch["loss_weights"] # 各sampleごとのweight @@ -480,14 +485,14 @@ def remove_model(old_ckpt_name): logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]} progress_bar.set_postfix(**logs) - if args.logging_dir is not None: + if len(accelerator.trackers) > 0: logs = generate_step_logs(args, current_loss, avr_loss, lr_scheduler) accelerator.log(logs, step=global_step) if global_step >= args.max_train_steps: break - if args.logging_dir is not None: + if len(accelerator.trackers) > 0: logs = {"loss/epoch": loss_recorder.moving_average} accelerator.log(logs, step=epoch + 1) @@ -532,6 +537,7 @@ def setup_parser() -> argparse.ArgumentParser: add_logging_arguments(parser) train_util.add_sd_models_arguments(parser) + sai_model_spec.add_model_spec_arguments(parser) train_util.add_dataset_arguments(parser, False, True, True) train_util.add_training_arguments(parser, False) deepspeed_utils.add_deepspeed_arguments(parser) diff --git a/sdxl_train_network.py b/sdxl_train_network.py index 3559ab88f..5c5bcd63a 100644 --- a/sdxl_train_network.py +++ b/sdxl_train_network.py @@ -1,23 +1,34 @@ import argparse +from typing import List, Optional, Union import torch +from accelerate import Accelerator from library.device_utils import init_ipex, clean_memory_on_device + init_ipex() -from library import sdxl_model_util, sdxl_train_util, train_util +from library import sdxl_model_util, sdxl_train_util, strategy_base, strategy_sd, strategy_sdxl, train_util import train_network from library.utils import setup_logging + setup_logging() import logging + logger = logging.getLogger(__name__) + class SdxlNetworkTrainer(train_network.NetworkTrainer): def __init__(self): super().__init__() self.vae_scale_factor = sdxl_model_util.VAE_SCALE_FACTOR self.is_sdxl = True - def assert_extra_args(self, args, train_dataset_group): + def assert_extra_args( + self, + args, + train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset], + val_dataset_group: Optional[train_util.DatasetGroup], + ): sdxl_train_util.verify_sdxl_training_args(args) if args.cache_text_encoder_outputs: @@ -30,6 +41,8 @@ def assert_extra_args(self, args, train_dataset_group): ), "network for Text Encoder cannot be trained with caching Text Encoder outputs / Text Encoderの出力をキャッシュしながらText Encoderのネットワークを学習することはできません" train_dataset_group.verify_bucket_reso_steps(32) + if val_dataset_group is not None: + val_dataset_group.verify_bucket_reso_steps(32) def load_target_model(self, args, weight_dtype, accelerator): ( @@ -46,17 +59,41 @@ def load_target_model(self, args, weight_dtype, accelerator): self.logit_scale = logit_scale self.ckpt_info = ckpt_info + # モデルに xformers とか memory efficient attention を組み込む + train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa) + if torch.__version__ >= "2.0.0": # PyTorch 2.0.0 以上対応のxformersなら以下が使える + vae.set_use_memory_efficient_attention_xformers(args.xformers) + return sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, [text_encoder1, text_encoder2], vae, unet - def load_tokenizer(self, args): - tokenizer = sdxl_train_util.load_tokenizers(args) - return tokenizer + def get_tokenize_strategy(self, args): + return strategy_sdxl.SdxlTokenizeStrategy(args.max_token_length, args.tokenizer_cache_dir) + + def get_tokenizers(self, tokenize_strategy: strategy_sdxl.SdxlTokenizeStrategy): + return [tokenize_strategy.tokenizer1, tokenize_strategy.tokenizer2] + + def get_latents_caching_strategy(self, args): + latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy( + False, args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check + ) + return latents_caching_strategy + + def get_text_encoding_strategy(self, args): + return strategy_sdxl.SdxlTextEncodingStrategy() - def is_text_encoder_outputs_cached(self, args): - return args.cache_text_encoder_outputs + def get_models_for_text_encoding(self, args, accelerator, text_encoders): + return text_encoders + [accelerator.unwrap_model(text_encoders[-1])] + + def get_text_encoder_outputs_caching_strategy(self, args): + if args.cache_text_encoder_outputs: + return strategy_sdxl.SdxlTextEncoderOutputsCachingStrategy( + args.cache_text_encoder_outputs_to_disk, None, args.skip_cache_check, is_weighted=args.weighted_captions + ) + else: + return None def cache_text_encoder_outputs_if_needed( - self, args, accelerator, unet, vae, tokenizers, text_encoders, dataset: train_util.DatasetGroup, weight_dtype + self, args, accelerator: Accelerator, unet, vae, text_encoders, dataset: train_util.DatasetGroup, weight_dtype ): if args.cache_text_encoder_outputs: if not args.lowram: @@ -69,15 +106,11 @@ def cache_text_encoder_outputs_if_needed( clean_memory_on_device(accelerator.device) # When TE is not be trained, it will not be prepared so we need to use explicit autocast + text_encoders[0].to(accelerator.device, dtype=weight_dtype) + text_encoders[1].to(accelerator.device, dtype=weight_dtype) with accelerator.autocast(): - dataset.cache_text_encoder_outputs( - tokenizers, - text_encoders, - accelerator.device, - weight_dtype, - args.cache_text_encoder_outputs_to_disk, - accelerator.is_main_process, - ) + dataset.new_cache_text_encoder_outputs(text_encoders + [accelerator.unwrap_model(text_encoders[-1])], accelerator) + accelerator.wait_for_everyone() text_encoders[0].to("cpu", dtype=torch.float32) # Text Encoder doesn't work with fp16 on CPU text_encoders[1].to("cpu", dtype=torch.float32) @@ -146,7 +179,18 @@ def get_text_cond(self, args, accelerator, batch, tokenizers, text_encoders, wei return encoder_hidden_states1, encoder_hidden_states2, pool2 - def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_conds, batch, weight_dtype): + def call_unet( + self, + args, + accelerator, + unet, + noisy_latents, + timesteps, + text_conds, + batch, + weight_dtype, + indices: Optional[List[int]] = None, + ): noisy_latents = noisy_latents.to(weight_dtype) # TODO check why noisy_latents is not weight_dtype # get size embeddings @@ -160,6 +204,12 @@ def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_cond vector_embedding = torch.cat([pool2, embs], dim=1).to(weight_dtype) text_embedding = torch.cat([encoder_hidden_states1, encoder_hidden_states2], dim=2).to(weight_dtype) + if indices is not None and len(indices) > 0: + noisy_latents = noisy_latents[indices] + timesteps = timesteps[indices] + text_embedding = text_embedding[indices] + vector_embedding = vector_embedding[indices] + noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_embedding) return noise_pred diff --git a/sdxl_train_textual_inversion.py b/sdxl_train_textual_inversion.py index d8422f083..6dec31def 100644 --- a/sdxl_train_textual_inversion.py +++ b/sdxl_train_textual_inversion.py @@ -1,5 +1,6 @@ import argparse import os +from typing import Optional, Union import regex @@ -8,8 +9,7 @@ init_ipex() -from library import sdxl_model_util, sdxl_train_util, train_util - +from library import sdxl_model_util, sdxl_train_util, strategy_sd, strategy_sdxl, train_util import train_textual_inversion @@ -19,11 +19,13 @@ def __init__(self): self.vae_scale_factor = sdxl_model_util.VAE_SCALE_FACTOR self.is_sdxl = True - def assert_extra_args(self, args, train_dataset_group): + def assert_extra_args(self, args, train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset], val_dataset_group: Optional[train_util.DatasetGroup]): # super().assert_extra_args(args, train_dataset_group) # do not call parent because it checks reso steps with 64 sdxl_train_util.verify_sdxl_training_args(args, support_text_encoder_caching=False) train_dataset_group.verify_bucket_reso_steps(32) + if val_dataset_group is not None: + val_dataset_group.verify_bucket_reso_steps(32) def load_target_model(self, args, weight_dtype, accelerator): ( @@ -42,28 +44,20 @@ def load_target_model(self, args, weight_dtype, accelerator): return sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, [text_encoder1, text_encoder2], vae, unet - def load_tokenizer(self, args): - tokenizer = sdxl_train_util.load_tokenizers(args) - return tokenizer - - def get_text_cond(self, args, accelerator, batch, tokenizers, text_encoders, weight_dtype): - input_ids1 = batch["input_ids"] - input_ids2 = batch["input_ids2"] - with torch.enable_grad(): - input_ids1 = input_ids1.to(accelerator.device) - input_ids2 = input_ids2.to(accelerator.device) - encoder_hidden_states1, encoder_hidden_states2, pool2 = train_util.get_hidden_states_sdxl( - args.max_token_length, - input_ids1, - input_ids2, - tokenizers[0], - tokenizers[1], - text_encoders[0], - text_encoders[1], - None if not args.full_fp16 else weight_dtype, - accelerator=accelerator, - ) - return encoder_hidden_states1, encoder_hidden_states2, pool2 + def get_tokenize_strategy(self, args): + return strategy_sdxl.SdxlTokenizeStrategy(args.max_token_length, args.tokenizer_cache_dir) + + def get_tokenizers(self, tokenize_strategy: strategy_sdxl.SdxlTokenizeStrategy): + return [tokenize_strategy.tokenizer1, tokenize_strategy.tokenizer2] + + def get_latents_caching_strategy(self, args): + latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy( + False, args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check + ) + return latents_caching_strategy + + def get_text_encoding_strategy(self, args): + return strategy_sdxl.SdxlTextEncodingStrategy() def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_conds, batch, weight_dtype): noisy_latents = noisy_latents.to(weight_dtype) # TODO check why noisy_latents is not weight_dtype @@ -82,9 +76,11 @@ def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_cond noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_embedding) return noise_pred - def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet, prompt_replacement): + def sample_images( + self, accelerator, args, epoch, global_step, device, vae, tokenizers, text_encoders, unet, prompt_replacement + ): sdxl_train_util.sample_images( - accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet, prompt_replacement + accelerator, args, epoch, global_step, device, vae, tokenizers, text_encoders, unet, prompt_replacement ) def save_weights(self, file, updated_embs, save_dtype, metadata): diff --git a/tests/README.md b/tests/README.md new file mode 100644 index 000000000..9836da8b4 --- /dev/null +++ b/tests/README.md @@ -0,0 +1,41 @@ +# Tests + +## Install + +``` +pip install pytest +``` + +## Usage + +``` +pytest +``` + +## Contribution + +Pytest is configured to run tests in this directory. It might be a good idea to add tests closer in the code, as well as doctests. + +Tests are functions starting with `test_` and files with the pattern `test_*.py`. + +``` +def test_x(): + assert 1 == 2, "Invalid test response" +``` + +## Resources + +### pytest + +- https://docs.pytest.org/en/stable/index.html +- https://docs.pytest.org/en/stable/how-to/assert.html +- https://docs.pytest.org/en/stable/how-to/doctest.html + +### PyTorch testing + +- https://circleci.com/blog/testing-pytorch-model-with-pytest/ +- https://pytorch.org/docs/stable/testing.html +- https://github.com/pytorch/pytorch/wiki/Running-and-writing-tests +- https://github.com/huggingface/pytorch-image-models/tree/main/tests +- https://github.com/pytorch/pytorch/tree/main/test + diff --git a/tests/library/test_flux_train_utils.py b/tests/library/test_flux_train_utils.py new file mode 100644 index 000000000..2ad7ce4ee --- /dev/null +++ b/tests/library/test_flux_train_utils.py @@ -0,0 +1,220 @@ +import pytest +import torch +from unittest.mock import MagicMock, patch +from library.flux_train_utils import ( + get_noisy_model_input_and_timesteps, +) + +# Mock classes and functions +class MockNoiseScheduler: + def __init__(self, num_train_timesteps=1000): + self.config = MagicMock() + self.config.num_train_timesteps = num_train_timesteps + self.timesteps = torch.arange(num_train_timesteps, dtype=torch.long) + + +# Create fixtures for commonly used objects +@pytest.fixture +def args(): + args = MagicMock() + args.timestep_sampling = "uniform" + args.weighting_scheme = "uniform" + args.logit_mean = 0.0 + args.logit_std = 1.0 + args.mode_scale = 1.0 + args.sigmoid_scale = 1.0 + args.discrete_flow_shift = 3.1582 + args.ip_noise_gamma = None + args.ip_noise_gamma_random_strength = False + return args + + +@pytest.fixture +def noise_scheduler(): + return MockNoiseScheduler(num_train_timesteps=1000) + + +@pytest.fixture +def latents(): + return torch.randn(2, 4, 8, 8) + + +@pytest.fixture +def noise(): + return torch.randn(2, 4, 8, 8) + + +@pytest.fixture +def device(): + # return "cuda" if torch.cuda.is_available() else "cpu" + return "cpu" + + +# Mock the required functions +@pytest.fixture(autouse=True) +def mock_functions(): + with ( + patch("torch.sigmoid", side_effect=torch.sigmoid), + patch("torch.rand", side_effect=torch.rand), + patch("torch.randn", side_effect=torch.randn), + ): + yield + + +# Test different timestep sampling methods +def test_uniform_sampling(args, noise_scheduler, latents, noise, device): + args.timestep_sampling = "uniform" + dtype = torch.float32 + + noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype) + + assert noisy_input.shape == latents.shape + assert timesteps.shape == (latents.shape[0],) + assert sigmas.shape == (latents.shape[0], 1, 1, 1) + assert noisy_input.dtype == dtype + assert timesteps.dtype == dtype + + +def test_sigmoid_sampling(args, noise_scheduler, latents, noise, device): + args.timestep_sampling = "sigmoid" + args.sigmoid_scale = 1.0 + dtype = torch.float32 + + noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype) + + assert noisy_input.shape == latents.shape + assert timesteps.shape == (latents.shape[0],) + assert sigmas.shape == (latents.shape[0], 1, 1, 1) + + +def test_shift_sampling(args, noise_scheduler, latents, noise, device): + args.timestep_sampling = "shift" + args.sigmoid_scale = 1.0 + args.discrete_flow_shift = 3.1582 + dtype = torch.float32 + + noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype) + + assert noisy_input.shape == latents.shape + assert timesteps.shape == (latents.shape[0],) + assert sigmas.shape == (latents.shape[0], 1, 1, 1) + + +def test_flux_shift_sampling(args, noise_scheduler, latents, noise, device): + args.timestep_sampling = "flux_shift" + args.sigmoid_scale = 1.0 + dtype = torch.float32 + + noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype) + + assert noisy_input.shape == latents.shape + assert timesteps.shape == (latents.shape[0],) + assert sigmas.shape == (latents.shape[0], 1, 1, 1) + + +def test_weighting_scheme(args, noise_scheduler, latents, noise, device): + # Mock the necessary functions for this specific test + with patch("library.flux_train_utils.compute_density_for_timestep_sampling", + return_value=torch.tensor([0.3, 0.7], device=device)), \ + patch("library.flux_train_utils.get_sigmas", + return_value=torch.tensor([[0.3], [0.7]], device=device).view(-1, 1, 1, 1)): + + args.timestep_sampling = "other" # Will trigger the weighting scheme path + args.weighting_scheme = "uniform" + args.logit_mean = 0.0 + args.logit_std = 1.0 + args.mode_scale = 1.0 + dtype = torch.float32 + + noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps( + args, noise_scheduler, latents, noise, device, dtype + ) + + assert noisy_input.shape == latents.shape + assert timesteps.shape == (latents.shape[0],) + assert sigmas.shape == (latents.shape[0], 1, 1, 1) + + +# Test IP noise options +def test_with_ip_noise(args, noise_scheduler, latents, noise, device): + args.ip_noise_gamma = 0.5 + args.ip_noise_gamma_random_strength = False + dtype = torch.float32 + + noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype) + + assert noisy_input.shape == latents.shape + assert timesteps.shape == (latents.shape[0],) + assert sigmas.shape == (latents.shape[0], 1, 1, 1) + + +def test_with_random_ip_noise(args, noise_scheduler, latents, noise, device): + args.ip_noise_gamma = 0.1 + args.ip_noise_gamma_random_strength = True + dtype = torch.float32 + + noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype) + + assert noisy_input.shape == latents.shape + assert timesteps.shape == (latents.shape[0],) + assert sigmas.shape == (latents.shape[0], 1, 1, 1) + + +# Test different data types +def test_float16_dtype(args, noise_scheduler, latents, noise, device): + dtype = torch.float16 + + noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype) + + assert noisy_input.dtype == dtype + assert timesteps.dtype == dtype + + +# Test different batch sizes +def test_different_batch_size(args, noise_scheduler, device): + latents = torch.randn(5, 4, 8, 8) # batch size of 5 + noise = torch.randn(5, 4, 8, 8) + dtype = torch.float32 + + noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype) + + assert noisy_input.shape == latents.shape + assert timesteps.shape == (5,) + assert sigmas.shape == (5, 1, 1, 1) + + +# Test different image sizes +def test_different_image_size(args, noise_scheduler, device): + latents = torch.randn(2, 4, 16, 16) # larger image size + noise = torch.randn(2, 4, 16, 16) + dtype = torch.float32 + + noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype) + + assert noisy_input.shape == latents.shape + assert timesteps.shape == (2,) + assert sigmas.shape == (2, 1, 1, 1) + + +# Test edge cases +def test_zero_batch_size(args, noise_scheduler, device): + with pytest.raises(AssertionError): # expecting an error with zero batch size + latents = torch.randn(0, 4, 8, 8) + noise = torch.randn(0, 4, 8, 8) + dtype = torch.float32 + + get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype) + + +def test_different_timestep_count(args, device): + noise_scheduler = MockNoiseScheduler(num_train_timesteps=500) # different timestep count + latents = torch.randn(2, 4, 8, 8) + noise = torch.randn(2, 4, 8, 8) + dtype = torch.float32 + + noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype) + + assert noisy_input.shape == latents.shape + assert timesteps.shape == (2,) + # Check that timesteps are within the proper range + assert torch.all(timesteps < 500) diff --git a/tests/library/test_lumina_models.py b/tests/library/test_lumina_models.py new file mode 100644 index 000000000..ba063688c --- /dev/null +++ b/tests/library/test_lumina_models.py @@ -0,0 +1,295 @@ +import pytest +import torch + +from library.lumina_models import ( + LuminaParams, + to_cuda, + to_cpu, + RopeEmbedder, + TimestepEmbedder, + modulate, + NextDiT, +) + +cuda_required = pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") + + +def test_lumina_params(): + # Test default configuration + default_params = LuminaParams() + assert default_params.patch_size == 2 + assert default_params.in_channels == 4 + assert default_params.axes_dims == [36, 36, 36] + assert default_params.axes_lens == [300, 512, 512] + + # Test 2B config + config_2b = LuminaParams.get_2b_config() + assert config_2b.dim == 2304 + assert config_2b.in_channels == 16 + assert config_2b.n_layers == 26 + assert config_2b.n_heads == 24 + assert config_2b.cap_feat_dim == 2304 + + # Test 7B config + config_7b = LuminaParams.get_7b_config() + assert config_7b.dim == 4096 + assert config_7b.n_layers == 32 + assert config_7b.n_heads == 32 + assert config_7b.axes_dims == [64, 64, 64] + + +@cuda_required +def test_to_cuda_to_cpu(): + # Test tensor conversion + x = torch.tensor([1, 2, 3]) + x_cuda = to_cuda(x) + x_cpu = to_cpu(x_cuda) + assert x.cpu().tolist() == x_cpu.tolist() + + # Test list conversion + list_data = [torch.tensor([1]), torch.tensor([2])] + list_cuda = to_cuda(list_data) + assert all(tensor.device.type == "cuda" for tensor in list_cuda) + + list_cpu = to_cpu(list_cuda) + assert all(not tensor.device.type == "cuda" for tensor in list_cpu) + + # Test dict conversion + dict_data = {"a": torch.tensor([1]), "b": torch.tensor([2])} + dict_cuda = to_cuda(dict_data) + assert all(tensor.device.type == "cuda" for tensor in dict_cuda.values()) + + dict_cpu = to_cpu(dict_cuda) + assert all(not tensor.device.type == "cuda" for tensor in dict_cpu.values()) + + +def test_timestep_embedder(): + # Test initialization + hidden_size = 256 + freq_emb_size = 128 + embedder = TimestepEmbedder(hidden_size, freq_emb_size) + assert embedder.frequency_embedding_size == freq_emb_size + + # Test timestep embedding + t = torch.tensor([0.5, 1.0, 2.0]) + emb_dim = freq_emb_size + embeddings = TimestepEmbedder.timestep_embedding(t, emb_dim) + + assert embeddings.shape == (3, emb_dim) + assert embeddings.dtype == torch.float32 + + # Ensure embeddings are unique for different input times + assert not torch.allclose(embeddings[0], embeddings[1]) + + # Test forward pass + t_emb = embedder(t) + assert t_emb.shape == (3, hidden_size) + + +def test_rope_embedder_simple(): + rope_embedder = RopeEmbedder() + batch_size, seq_len = 2, 10 + + # Create position_ids with valid ranges for each axis + position_ids = torch.stack( + [ + torch.zeros(batch_size, seq_len, dtype=torch.int64), # First axis: only 0 is valid + torch.randint(0, 512, (batch_size, seq_len), dtype=torch.int64), # Second axis: 0-511 + torch.randint(0, 512, (batch_size, seq_len), dtype=torch.int64), # Third axis: 0-511 + ], + dim=-1, + ) + + freqs_cis = rope_embedder(position_ids) + # RoPE embeddings work in pairs, so output dimension is half of total axes_dims + expected_dim = sum(rope_embedder.axes_dims) // 2 # 128 // 2 = 64 + assert freqs_cis.shape == (batch_size, seq_len, expected_dim) + + +def test_modulate(): + # Test modulation with different scales + x = torch.tensor([[1.0, 2.0], [3.0, 4.0]]) + scale = torch.tensor([1.5, 2.0]) + + modulated_x = modulate(x, scale) + + # Check that modulation scales correctly + # The function does x * (1 + scale), so: + # For scale [1.5, 2.0], (1 + scale) = [2.5, 3.0] + expected_x = torch.tensor([[2.5 * 1.0, 2.5 * 2.0], [3.0 * 3.0, 3.0 * 4.0]]) + # Which equals: [[2.5, 5.0], [9.0, 12.0]] + + assert torch.allclose(modulated_x, expected_x) + + +def test_nextdit_parameter_count_optimized(): + # The constraint is: (dim // n_heads) == sum(axes_dims) + # So for dim=120, n_heads=4: 120//4 = 30, so sum(axes_dims) must = 30 + model_small = NextDiT( + patch_size=2, + in_channels=4, # Smaller + dim=120, # 120 // 4 = 30 + n_layers=2, # Much fewer layers + n_heads=4, # Fewer heads + n_kv_heads=2, + axes_dims=[10, 10, 10], # sum = 30 + axes_lens=[10, 32, 32], # Smaller + ) + param_count_small = model_small.parameter_count() + assert param_count_small > 0 + + # For dim=192, n_heads=6: 192//6 = 32, so sum(axes_dims) must = 32 + model_medium = NextDiT( + patch_size=2, + in_channels=4, + dim=192, # 192 // 6 = 32 + n_layers=4, # More layers + n_heads=6, + n_kv_heads=3, + axes_dims=[10, 11, 11], # sum = 32 + axes_lens=[10, 32, 32], + ) + param_count_medium = model_medium.parameter_count() + assert param_count_medium > param_count_small + print(f"Small model: {param_count_small:,} parameters") + print(f"Medium model: {param_count_medium:,} parameters") + + +@torch.no_grad() +def test_precompute_freqs_cis(): + # Test precompute_freqs_cis + dim = [16, 56, 56] + end = [1, 512, 512] + theta = 10000.0 + + freqs_cis = NextDiT.precompute_freqs_cis(dim, end, theta) + + # Check number of frequency tensors + assert len(freqs_cis) == len(dim) + + # Check each frequency tensor + for i, (d, e) in enumerate(zip(dim, end)): + assert freqs_cis[i].shape == (e, d // 2) + assert freqs_cis[i].dtype == torch.complex128 + + +@torch.no_grad() +def test_nextdit_patchify_and_embed(): + """Test the patchify_and_embed method which is crucial for training""" + # Create a small NextDiT model for testing + # The constraint is: (dim // n_heads) == sum(axes_dims) + # For dim=120, n_heads=4: 120//4 = 30, so sum(axes_dims) must = 30 + model = NextDiT( + patch_size=2, + in_channels=4, + dim=120, # 120 // 4 = 30 + n_layers=1, # Minimal layers for faster testing + n_refiner_layers=1, # Minimal refiner layers + n_heads=4, + n_kv_heads=2, + axes_dims=[10, 10, 10], # sum = 30 + axes_lens=[10, 32, 32], + cap_feat_dim=120, # Match dim for consistency + ) + + # Prepare test inputs + batch_size = 2 + height, width = 64, 64 # Must be divisible by patch_size (2) + caption_seq_len = 8 + + # Create mock inputs + x = torch.randn(batch_size, 4, height, width) # Image latents + cap_feats = torch.randn(batch_size, caption_seq_len, 120) # Caption features + cap_mask = torch.ones(batch_size, caption_seq_len, dtype=torch.bool) # All valid tokens + # Make second batch have shorter caption + cap_mask[1, 6:] = False # Only first 6 tokens are valid for second batch + t = torch.randn(batch_size, 120) # Timestep embeddings + + # Call patchify_and_embed + joint_hidden_states, attention_mask, freqs_cis, l_effective_cap_len, seq_lengths = model.patchify_and_embed( + x, cap_feats, cap_mask, t + ) + + # Validate outputs + image_seq_len = (height // 2) * (width // 2) # patch_size = 2 + expected_seq_lengths = [caption_seq_len + image_seq_len, 6 + image_seq_len] # Second batch has shorter caption + max_seq_len = max(expected_seq_lengths) + + # Check joint hidden states shape + assert joint_hidden_states.shape == (batch_size, max_seq_len, 120) + assert joint_hidden_states.dtype == torch.float32 + + # Check attention mask shape and values + assert attention_mask.shape == (batch_size, max_seq_len) + assert attention_mask.dtype == torch.bool + # First batch should have all positions valid up to its sequence length + assert torch.all(attention_mask[0, : expected_seq_lengths[0]]) + assert torch.all(~attention_mask[0, expected_seq_lengths[0] :]) + # Second batch should have all positions valid up to its sequence length + assert torch.all(attention_mask[1, : expected_seq_lengths[1]]) + assert torch.all(~attention_mask[1, expected_seq_lengths[1] :]) + + # Check freqs_cis shape + assert freqs_cis.shape == (batch_size, max_seq_len, sum(model.axes_dims) // 2) + + # Check effective caption lengths + assert l_effective_cap_len == [caption_seq_len, 6] + + # Check sequence lengths + assert seq_lengths == expected_seq_lengths + + # Validate that the joint hidden states contain non-zero values where attention mask is True + for i in range(batch_size): + valid_positions = attention_mask[i] + # Check that valid positions have meaningful data (not all zeros) + valid_data = joint_hidden_states[i][valid_positions] + assert not torch.allclose(valid_data, torch.zeros_like(valid_data)) + + # Check that invalid positions are zeros + if valid_positions.sum() < max_seq_len: + invalid_data = joint_hidden_states[i][~valid_positions] + assert torch.allclose(invalid_data, torch.zeros_like(invalid_data)) + + +@torch.no_grad() +def test_nextdit_patchify_and_embed_edge_cases(): + """Test edge cases for patchify_and_embed""" + # Create minimal model + model = NextDiT( + patch_size=2, + in_channels=4, + dim=60, # 60 // 3 = 20 + n_layers=1, + n_refiner_layers=1, + n_heads=3, + n_kv_heads=1, + axes_dims=[8, 6, 6], # sum = 20 + axes_lens=[10, 16, 16], + cap_feat_dim=60, + ) + + # Test with empty captions (all masked) + batch_size = 1 + height, width = 32, 32 + caption_seq_len = 4 + + x = torch.randn(batch_size, 4, height, width) + cap_feats = torch.randn(batch_size, caption_seq_len, 60) + cap_mask = torch.zeros(batch_size, caption_seq_len, dtype=torch.bool) # All tokens masked + t = torch.randn(batch_size, 60) + + joint_hidden_states, attention_mask, freqs_cis, l_effective_cap_len, seq_lengths = model.patchify_and_embed( + x, cap_feats, cap_mask, t + ) + + # With all captions masked, effective length should be 0 + assert l_effective_cap_len == [0] + + # Sequence length should just be the image sequence length + image_seq_len = (height // 2) * (width // 2) + assert seq_lengths == [image_seq_len] + + # Joint hidden states should only contain image data + assert joint_hidden_states.shape == (batch_size, image_seq_len, 60) + assert attention_mask.shape == (batch_size, image_seq_len) + assert torch.all(attention_mask[0]) # All image positions should be valid diff --git a/tests/library/test_lumina_train_util.py b/tests/library/test_lumina_train_util.py new file mode 100644 index 000000000..bcf448c89 --- /dev/null +++ b/tests/library/test_lumina_train_util.py @@ -0,0 +1,241 @@ +import pytest +import torch +import math + +from library.lumina_train_util import ( + batchify, + time_shift, + get_lin_function, + get_schedule, + compute_density_for_timestep_sampling, + get_sigmas, + compute_loss_weighting_for_sd3, + get_noisy_model_input_and_timesteps, + apply_model_prediction_type, + retrieve_timesteps, +) +from library.sd3_train_utils import FlowMatchEulerDiscreteScheduler + + +def test_batchify(): + # Test case with no batch size specified + prompts = [ + {"prompt": "test1"}, + {"prompt": "test2"}, + {"prompt": "test3"} + ] + batchified = list(batchify(prompts)) + assert len(batchified) == 1 + assert len(batchified[0]) == 3 + + # Test case with batch size specified + batchified_sized = list(batchify(prompts, batch_size=2)) + assert len(batchified_sized) == 2 + assert len(batchified_sized[0]) == 2 + assert len(batchified_sized[1]) == 1 + + # Test batching with prompts having same parameters + prompts_with_params = [ + {"prompt": "test1", "width": 512, "height": 512}, + {"prompt": "test2", "width": 512, "height": 512}, + {"prompt": "test3", "width": 1024, "height": 1024} + ] + batchified_params = list(batchify(prompts_with_params)) + assert len(batchified_params) == 2 + + # Test invalid batch size + with pytest.raises(ValueError): + list(batchify(prompts, batch_size=0)) + with pytest.raises(ValueError): + list(batchify(prompts, batch_size=-1)) + + +def test_time_shift(): + # Test standard parameters + t = torch.tensor([0.5]) + mu = 1.0 + sigma = 1.0 + result = time_shift(mu, sigma, t) + assert 0 <= result <= 1 + + # Test with edge cases + t_edges = torch.tensor([0.0, 1.0]) + result_edges = time_shift(1.0, 1.0, t_edges) + + # Check that results are bounded within [0, 1] + assert torch.all(result_edges >= 0) + assert torch.all(result_edges <= 1) + + +def test_get_lin_function(): + # Default parameters + func = get_lin_function() + assert func(256) == 0.5 + assert func(4096) == 1.15 + + # Custom parameters + custom_func = get_lin_function(x1=100, x2=1000, y1=0.1, y2=0.9) + assert custom_func(100) == 0.1 + assert custom_func(1000) == 0.9 + + +def test_get_schedule(): + # Basic schedule + schedule = get_schedule(num_steps=10, image_seq_len=256) + assert len(schedule) == 10 + assert all(0 <= x <= 1 for x in schedule) + + # Test different sequence lengths + short_schedule = get_schedule(num_steps=5, image_seq_len=128) + long_schedule = get_schedule(num_steps=15, image_seq_len=1024) + assert len(short_schedule) == 5 + assert len(long_schedule) == 15 + + # Test with shift disabled + unshifted_schedule = get_schedule(num_steps=10, image_seq_len=256, shift=False) + assert torch.allclose( + torch.tensor(unshifted_schedule), + torch.linspace(1, 1/10, 10) + ) + + +def test_compute_density_for_timestep_sampling(): + # Test uniform sampling + uniform_samples = compute_density_for_timestep_sampling("uniform", batch_size=100) + assert len(uniform_samples) == 100 + assert torch.all((uniform_samples >= 0) & (uniform_samples <= 1)) + + # Test logit normal sampling + logit_normal_samples = compute_density_for_timestep_sampling( + "logit_normal", batch_size=100, logit_mean=0.0, logit_std=1.0 + ) + assert len(logit_normal_samples) == 100 + assert torch.all((logit_normal_samples >= 0) & (logit_normal_samples <= 1)) + + # Test mode sampling + mode_samples = compute_density_for_timestep_sampling( + "mode", batch_size=100, mode_scale=0.5 + ) + assert len(mode_samples) == 100 + assert torch.all((mode_samples >= 0) & (mode_samples <= 1)) + + +def test_get_sigmas(): + # Create a mock noise scheduler + scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000) + device = torch.device('cpu') + + # Test with default parameters + timesteps = torch.tensor([100, 500, 900]) + sigmas = get_sigmas(scheduler, timesteps, device) + + # Check shape and basic properties + assert sigmas.shape[0] == 3 + assert torch.all(sigmas >= 0) + + # Test with different n_dim + sigmas_4d = get_sigmas(scheduler, timesteps, device, n_dim=4) + assert sigmas_4d.ndim == 4 + + # Test with different dtype + sigmas_float16 = get_sigmas(scheduler, timesteps, device, dtype=torch.float16) + assert sigmas_float16.dtype == torch.float16 + + +def test_compute_loss_weighting_for_sd3(): + # Prepare some mock sigmas + sigmas = torch.tensor([0.1, 0.5, 1.0]) + + # Test sigma_sqrt weighting + sqrt_weighting = compute_loss_weighting_for_sd3("sigma_sqrt", sigmas) + assert torch.allclose(sqrt_weighting, 1 / (sigmas**2), rtol=1e-5) + + # Test cosmap weighting + cosmap_weighting = compute_loss_weighting_for_sd3("cosmap", sigmas) + bot = 1 - 2 * sigmas + 2 * sigmas**2 + expected_cosmap = 2 / (math.pi * bot) + assert torch.allclose(cosmap_weighting, expected_cosmap, rtol=1e-5) + + # Test default weighting + default_weighting = compute_loss_weighting_for_sd3("unknown", sigmas) + assert torch.all(default_weighting == 1) + + +def test_apply_model_prediction_type(): + # Create mock args and tensors + class MockArgs: + model_prediction_type = "raw" + weighting_scheme = "sigma_sqrt" + + args = MockArgs() + model_pred = torch.tensor([1.0, 2.0, 3.0]) + noisy_model_input = torch.tensor([0.5, 1.0, 1.5]) + sigmas = torch.tensor([0.1, 0.5, 1.0]) + + # Test raw prediction type + raw_pred, raw_weighting = apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas) + assert torch.all(raw_pred == model_pred) + assert raw_weighting is None + + # Test additive prediction type + args.model_prediction_type = "additive" + additive_pred, _ = apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas) + assert torch.all(additive_pred == model_pred + noisy_model_input) + + # Test sigma scaled prediction type + args.model_prediction_type = "sigma_scaled" + sigma_scaled_pred, sigma_weighting = apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas) + assert torch.all(sigma_scaled_pred == model_pred * (-sigmas) + noisy_model_input) + assert sigma_weighting is not None + + +def test_retrieve_timesteps(): + # Create a mock scheduler + scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000) + + # Test with num_inference_steps + timesteps, n_steps = retrieve_timesteps(scheduler, num_inference_steps=50) + assert len(timesteps) == 50 + assert n_steps == 50 + + # Test error handling with simultaneous timesteps and sigmas + with pytest.raises(ValueError): + retrieve_timesteps(scheduler, timesteps=[1, 2, 3], sigmas=[0.1, 0.2, 0.3]) + + +def test_get_noisy_model_input_and_timesteps(): + # Create a mock args and setup + class MockArgs: + timestep_sampling = "uniform" + weighting_scheme = "sigma_sqrt" + sigmoid_scale = 1.0 + discrete_flow_shift = 6.0 + + args = MockArgs() + scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000) + device = torch.device('cpu') + + # Prepare mock latents and noise + latents = torch.randn(4, 16, 64, 64) + noise = torch.randn_like(latents) + + # Test uniform sampling + noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps( + args, scheduler, latents, noise, device, torch.float32 + ) + + # Validate output shapes and types + assert noisy_input.shape == latents.shape + assert timesteps.shape[0] == latents.shape[0] + assert noisy_input.dtype == torch.float32 + assert timesteps.dtype == torch.float32 + + # Test different sampling methods + sampling_methods = ["sigmoid", "shift", "nextdit_shift"] + for method in sampling_methods: + args.timestep_sampling = method + noisy_input, timesteps, _ = get_noisy_model_input_and_timesteps( + args, scheduler, latents, noise, device, torch.float32 + ) + assert noisy_input.shape == latents.shape + assert timesteps.shape[0] == latents.shape[0] diff --git a/tests/library/test_lumina_util.py b/tests/library/test_lumina_util.py new file mode 100644 index 000000000..397bab5a9 --- /dev/null +++ b/tests/library/test_lumina_util.py @@ -0,0 +1,112 @@ +import torch +from torch.nn.modules import conv + +from library import lumina_util + + +def test_unpack_latents(): + # Create a test tensor + # Shape: [batch, height*width, channels*patch_height*patch_width] + x = torch.randn(2, 4, 16) # 2 batches, 4 tokens, 16 channels + packed_latent_height = 2 + packed_latent_width = 2 + + # Unpack the latents + unpacked = lumina_util.unpack_latents(x, packed_latent_height, packed_latent_width) + + # Check output shape + # Expected shape: [batch, channels, height*patch_height, width*patch_width] + assert unpacked.shape == (2, 4, 4, 4) + + +def test_pack_latents(): + # Create a test tensor + # Shape: [batch, channels, height*patch_height, width*patch_width] + x = torch.randn(2, 4, 4, 4) + + # Pack the latents + packed = lumina_util.pack_latents(x) + + # Check output shape + # Expected shape: [batch, height*width, channels*patch_height*patch_width] + assert packed.shape == (2, 4, 16) + + +def test_convert_diffusers_sd_to_alpha_vllm(): + num_double_blocks = 2 + # Predefined test cases based on the actual conversion map + test_cases = [ + # Static key conversions with possible list mappings + { + "original_keys": ["time_caption_embed.caption_embedder.0.weight"], + "original_pattern": ["time_caption_embed.caption_embedder.0.weight"], + "expected_converted_keys": ["cap_embedder.0.weight"], + }, + { + "original_keys": ["patch_embedder.proj.weight"], + "original_pattern": ["patch_embedder.proj.weight"], + "expected_converted_keys": ["x_embedder.weight"], + }, + { + "original_keys": ["transformer_blocks.0.norm1.weight"], + "original_pattern": ["transformer_blocks.().norm1.weight"], + "expected_converted_keys": ["layers.0.attention_norm1.weight"], + }, + ] + + + for test_case in test_cases: + for original_key, original_pattern, expected_converted_key in zip( + test_case["original_keys"], test_case["original_pattern"], test_case["expected_converted_keys"] + ): + # Create test state dict + test_sd = {original_key: torch.randn(10, 10)} + + # Convert the state dict + converted_sd = lumina_util.convert_diffusers_sd_to_alpha_vllm(test_sd, num_double_blocks) + + # Verify conversion (handle both string and list keys) + # Find the correct converted key + match_found = False + if expected_converted_key in converted_sd: + # Verify tensor preservation + assert torch.allclose(converted_sd[expected_converted_key], test_sd[original_key], atol=1e-6), ( + f"Tensor mismatch for {original_key}" + ) + match_found = True + break + + assert match_found, f"Failed to convert {original_key}" + + # Ensure original key is also present + assert original_key in converted_sd + + # Test with block-specific keys + block_specific_cases = [ + { + "original_pattern": "transformer_blocks.().norm1.weight", + "converted_pattern": "layers.().attention_norm1.weight", + } + ] + + for case in block_specific_cases: + for block_idx in range(2): # Test multiple block indices + # Prepare block-specific keys + block_original_key = case["original_pattern"].replace("()", str(block_idx)) + block_converted_key = case["converted_pattern"].replace("()", str(block_idx)) + print(block_original_key, block_converted_key) + + # Create test state dict + test_sd = {block_original_key: torch.randn(10, 10)} + + # Convert the state dict + converted_sd = lumina_util.convert_diffusers_sd_to_alpha_vllm(test_sd, num_double_blocks) + + # Verify conversion + # assert block_converted_key in converted_sd, f"Failed to convert block key {block_original_key}" + assert torch.allclose(converted_sd[block_converted_key], test_sd[block_original_key], atol=1e-6), ( + f"Tensor mismatch for block key {block_original_key}" + ) + + # Ensure original key is also present + assert block_original_key in converted_sd diff --git a/tests/library/test_sai_model_spec.py b/tests/library/test_sai_model_spec.py new file mode 100644 index 000000000..0bbfa1167 --- /dev/null +++ b/tests/library/test_sai_model_spec.py @@ -0,0 +1,360 @@ +"""Tests for sai_model_spec module.""" + +import pytest +import time + +from library import sai_model_spec + + +class MockArgs: + """Mock argparse.Namespace for testing.""" + + def __init__(self, **kwargs): + # Default values + self.v2 = False + self.v_parameterization = False + self.resolution = 512 + self.metadata_title = None + self.metadata_author = None + self.metadata_description = None + self.metadata_license = None + self.metadata_tags = None + self.min_timestep = None + self.max_timestep = None + self.clip_skip = None + self.output_name = "test_output" + + # Override with provided values + for key, value in kwargs.items(): + setattr(self, key, value) + + +class TestModelSpecMetadata: + """Test the ModelSpecMetadata dataclass.""" + + def test_creation_and_conversion(self): + """Test creating dataclass and converting to metadata dict.""" + metadata = sai_model_spec.ModelSpecMetadata( + architecture="stable-diffusion-v1", + implementation="diffusers", + title="Test Model", + resolution="512x512", + author="Test Author", + description=None, # Test None exclusion + ) + + assert metadata.architecture == "stable-diffusion-v1" + assert metadata.sai_model_spec == "1.0.1" + + metadata_dict = metadata.to_metadata_dict() + assert "modelspec.architecture" in metadata_dict + assert "modelspec.author" in metadata_dict + assert "modelspec.description" not in metadata_dict # None values excluded + assert metadata_dict["modelspec.sai_model_spec"] == "1.0.1" + + def test_additional_fields_handling(self): + """Test handling of additional metadata fields.""" + additional = {"custom_field": "custom_value", "modelspec.prefixed": "prefixed_value"} + + metadata = sai_model_spec.ModelSpecMetadata( + architecture="stable-diffusion-v1", + implementation="diffusers", + title="Test Model", + resolution="512x512", + additional_fields=additional, + ) + + metadata_dict = metadata.to_metadata_dict() + assert "modelspec.custom_field" in metadata_dict + assert "modelspec.prefixed" in metadata_dict + assert metadata_dict["modelspec.custom_field"] == "custom_value" + + def test_from_args_extraction(self): + """Test creating ModelSpecMetadata from args with metadata_* fields.""" + args = MockArgs(metadata_author="Test Author", metadata_trigger_phrase="anime style", metadata_usage_hint="Use CFG 7.5") + + metadata = sai_model_spec.ModelSpecMetadata.from_args( + args, + architecture="stable-diffusion-v1", + implementation="diffusers", + title="Test Model", + resolution="512x512", + ) + + assert metadata.author == "Test Author" + assert metadata.additional_fields["trigger_phrase"] == "anime style" + assert metadata.additional_fields["usage_hint"] == "Use CFG 7.5" + + +class TestArchitectureDetection: + """Test architecture detection for different model types.""" + + @pytest.mark.parametrize( + "config,expected", + [ + ({"v2": False, "v_parameterization": False, "sdxl": True}, "stable-diffusion-xl-v1-base"), + ({"v2": False, "v_parameterization": False, "sdxl": False, "model_config": {"flux": "dev"}}, "flux-1-dev"), + ({"v2": False, "v_parameterization": False, "sdxl": False, "model_config": {"flux": "chroma"}}, "chroma"), + ( + {"v2": False, "v_parameterization": False, "sdxl": False, "model_config": {"sd3": "large"}}, + "stable-diffusion-3-large", + ), + ({"v2": True, "v_parameterization": True, "sdxl": False}, "stable-diffusion-v2-768-v"), + ({"v2": False, "v_parameterization": False, "sdxl": False}, "stable-diffusion-v1"), + ], + ) + def test_architecture_detection(self, config, expected): + """Test architecture detection for various model configurations.""" + model_config = config.pop("model_config", None) + arch = sai_model_spec.determine_architecture(lora=False, textual_inversion=False, model_config=model_config, **config) + assert arch == expected + + def test_adapter_suffixes(self): + """Test LoRA and textual inversion suffixes.""" + lora_arch = sai_model_spec.determine_architecture( + v2=False, v_parameterization=False, sdxl=True, lora=True, textual_inversion=False + ) + assert lora_arch == "stable-diffusion-xl-v1-base/lora" + + ti_arch = sai_model_spec.determine_architecture( + v2=False, v_parameterization=False, sdxl=False, lora=False, textual_inversion=True + ) + assert ti_arch == "stable-diffusion-v1/textual-inversion" + + +class TestImplementationDetection: + """Test implementation detection for different model types.""" + + @pytest.mark.parametrize( + "config,expected", + [ + ({"model_config": {"flux": "dev"}}, "https://github.com/black-forest-labs/flux"), + ({"model_config": {"flux": "chroma"}}, "https://huggingface.co/lodestones/Chroma"), + ({"model_config": {"lumina": "lumina2"}}, "https://github.com/Alpha-VLLM/Lumina-Image-2.0"), + ({"lora": True, "sdxl": True}, "https://github.com/Stability-AI/generative-models"), + ({"lora": True, "sdxl": False}, "diffusers"), + ], + ) + def test_implementation_detection(self, config, expected): + """Test implementation detection for various configurations.""" + model_config = config.pop("model_config", None) + impl = sai_model_spec.determine_implementation( + lora=config.get("lora", False), textual_inversion=False, sdxl=config.get("sdxl", False), model_config=model_config + ) + assert impl == expected + + +class TestResolutionHandling: + """Test resolution parsing and defaults.""" + + @pytest.mark.parametrize( + "input_reso,expected", + [ + ((768, 1024), "768x1024"), + (768, "768x768"), + ("768,1024", "768x1024"), + ], + ) + def test_explicit_resolution_formats(self, input_reso, expected): + """Test different resolution input formats.""" + res = sai_model_spec.determine_resolution(reso=input_reso) + assert res == expected + + @pytest.mark.parametrize( + "config,expected", + [ + ({"sdxl": True}, "1024x1024"), + ({"model_config": {"flux": "dev"}}, "1024x1024"), + ({"v2": True, "v_parameterization": True}, "768x768"), + ({}, "512x512"), # Default SD v1 + ], + ) + def test_default_resolutions(self, config, expected): + """Test default resolution detection.""" + model_config = config.pop("model_config", None) + res = sai_model_spec.determine_resolution(model_config=model_config, **config) + assert res == expected + + +class TestThumbnailProcessing: + """Test thumbnail data URL processing.""" + + def test_file_to_data_url(self): + """Test converting file to data URL.""" + import tempfile + import os + + # Create a tiny test PNG (1x1 pixel) + test_png_data = b"\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x01\x00\x00\x00\x01\x08\x02\x00\x00\x00\x90wS\xde\x00\x00\x00\x0cIDATx\x9cc\xff\xff\xff\x00\x00\x00\x04\x00\x01\x9d\xb3\xa7c\x00\x00\x00\x00IEND\xaeB`\x82" + + with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as f: + f.write(test_png_data) + temp_path = f.name + + try: + data_url = sai_model_spec.file_to_data_url(temp_path) + + # Check format + assert data_url.startswith("data:image/png;base64,") + + # Check it's a reasonable length (base64 encoded) + assert len(data_url) > 50 + + # Verify we can decode it back + import base64 + + encoded_part = data_url.split(",", 1)[1] + decoded_data = base64.b64decode(encoded_part) + assert decoded_data == test_png_data + + finally: + os.unlink(temp_path) + + def test_file_to_data_url_nonexistent_file(self): + """Test error handling for nonexistent files.""" + import pytest + + with pytest.raises(FileNotFoundError): + sai_model_spec.file_to_data_url("/nonexistent/file.png") + + def test_thumbnail_processing_in_metadata(self): + """Test thumbnail processing in build_metadata_dataclass.""" + import tempfile + import os + + # Create a test image file + test_png_data = b"\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x01\x00\x00\x00\x01\x08\x02\x00\x00\x00\x90wS\xde\x00\x00\x00\x0cIDATx\x9cc\xff\xff\xff\x00\x00\x00\x04\x00\x01\x9d\xb3\xa7c\x00\x00\x00\x00IEND\xaeB`\x82" + + with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as f: + f.write(test_png_data) + temp_path = f.name + + try: + timestamp = time.time() + + # Test with file path - should be converted to data URL + metadata = sai_model_spec.build_metadata_dataclass( + state_dict=None, + v2=False, + v_parameterization=False, + sdxl=False, + lora=False, + textual_inversion=False, + timestamp=timestamp, + title="Test Model", + optional_metadata={"thumbnail": temp_path}, + ) + + # Should be converted to data URL + assert "thumbnail" in metadata.additional_fields + assert metadata.additional_fields["thumbnail"].startswith("data:image/png;base64,") + + finally: + os.unlink(temp_path) + + def test_thumbnail_data_url_passthrough(self): + """Test that existing data URLs are passed through unchanged.""" + timestamp = time.time() + + existing_data_url = ( + "" + ) + + metadata = sai_model_spec.build_metadata_dataclass( + state_dict=None, + v2=False, + v_parameterization=False, + sdxl=False, + lora=False, + textual_inversion=False, + timestamp=timestamp, + title="Test Model", + optional_metadata={"thumbnail": existing_data_url}, + ) + + # Should be unchanged + assert metadata.additional_fields["thumbnail"] == existing_data_url + + def test_thumbnail_invalid_file_handling(self): + """Test graceful handling of invalid thumbnail files.""" + timestamp = time.time() + + metadata = sai_model_spec.build_metadata_dataclass( + state_dict=None, + v2=False, + v_parameterization=False, + sdxl=False, + lora=False, + textual_inversion=False, + timestamp=timestamp, + title="Test Model", + optional_metadata={"thumbnail": "/nonexistent/file.png"}, + ) + + # Should be removed from additional_fields due to error + assert "thumbnail" not in metadata.additional_fields + + +class TestBuildMetadataIntegration: + """Test the complete metadata building workflow.""" + + def test_sdxl_model_workflow(self): + """Test complete workflow for SDXL model.""" + timestamp = time.time() + + metadata = sai_model_spec.build_metadata_dataclass( + state_dict=None, + v2=False, + v_parameterization=False, + sdxl=True, + lora=False, + textual_inversion=False, + timestamp=timestamp, + title="Test SDXL Model", + ) + + assert metadata.architecture == "stable-diffusion-xl-v1-base" + assert metadata.implementation == "https://github.com/Stability-AI/generative-models" + assert metadata.resolution == "1024x1024" + assert metadata.prediction_type == "epsilon" + + def test_flux_model_workflow(self): + """Test complete workflow for Flux model.""" + timestamp = time.time() + + metadata = sai_model_spec.build_metadata_dataclass( + state_dict=None, + v2=False, + v_parameterization=False, + sdxl=False, + lora=False, + textual_inversion=False, + timestamp=timestamp, + title="Test Flux Model", + model_config={"flux": "dev"}, + optional_metadata={"trigger_phrase": "anime style"}, + ) + + assert metadata.architecture == "flux-1-dev" + assert metadata.implementation == "https://github.com/black-forest-labs/flux" + assert metadata.prediction_type is None # Flux doesn't use prediction_type + assert metadata.additional_fields["trigger_phrase"] == "anime style" + + def test_legacy_function_compatibility(self): + """Test that legacy build_metadata function works correctly.""" + timestamp = time.time() + + metadata_dict = sai_model_spec.build_metadata( + state_dict=None, + v2=False, + v_parameterization=False, + sdxl=True, + lora=False, + textual_inversion=False, + timestamp=timestamp, + title="Test Model", + ) + + assert isinstance(metadata_dict, dict) + assert metadata_dict["modelspec.sai_model_spec"] == "1.0.1" + assert metadata_dict["modelspec.architecture"] == "stable-diffusion-xl-v1-base" diff --git a/tests/library/test_strategy_lumina.py b/tests/library/test_strategy_lumina.py new file mode 100644 index 000000000..d77d27383 --- /dev/null +++ b/tests/library/test_strategy_lumina.py @@ -0,0 +1,241 @@ +import os +import tempfile +import torch +import numpy as np +from unittest.mock import patch +from transformers import Gemma2Model + +from library.strategy_lumina import ( + LuminaTokenizeStrategy, + LuminaTextEncodingStrategy, + LuminaTextEncoderOutputsCachingStrategy, + LuminaLatentsCachingStrategy, +) + + +class SimpleMockGemma2Model: + """Lightweight mock that avoids initializing the actual Gemma2Model""" + + def __init__(self, hidden_size=2304): + self.device = torch.device("cpu") + self._hidden_size = hidden_size + self._orig_mod = self # For dynamic compilation compatibility + + def __call__(self, input_ids, attention_mask, output_hidden_states=False, return_dict=False): + # Create a mock output object with hidden states + batch_size, seq_len = input_ids.shape + hidden_size = self._hidden_size + + class MockOutput: + def __init__(self, hidden_states): + self.hidden_states = hidden_states + + mock_hidden_states = [ + torch.randn(batch_size, seq_len, hidden_size, device=input_ids.device) + for _ in range(3) # Mimic multiple layers of hidden states + ] + + return MockOutput(mock_hidden_states) + + +def test_lumina_tokenize_strategy(): + # Test default initialization + try: + tokenize_strategy = LuminaTokenizeStrategy("dummy system prompt", max_length=None) + except OSError as e: + # If the tokenizer is not found (due to gated repo), we can skip the test + print(f"Skipping LuminaTokenizeStrategy test due to OSError: {e}") + return + assert tokenize_strategy.max_length == 256 + assert tokenize_strategy.tokenizer.padding_side == "right" + + # Test tokenization of a single string + text = "Hello" + tokens, attention_mask = tokenize_strategy.tokenize(text) + + assert tokens.ndim == 2 + assert attention_mask.ndim == 2 + assert tokens.shape == attention_mask.shape + assert tokens.shape[1] == 256 # max_length + + # Test tokenize_with_weights + tokens, attention_mask, weights = tokenize_strategy.tokenize_with_weights(text) + assert len(weights) == 1 + assert torch.all(weights[0] == 1) + + +def test_lumina_text_encoding_strategy(): + # Create strategies + try: + tokenize_strategy = LuminaTokenizeStrategy("dummy system prompt", max_length=None) + except OSError as e: + # If the tokenizer is not found (due to gated repo), we can skip the test + print(f"Skipping LuminaTokenizeStrategy test due to OSError: {e}") + return + encoding_strategy = LuminaTextEncodingStrategy() + + # Create a mock model + mock_model = SimpleMockGemma2Model() + + # Patch the isinstance check to accept our simple mock + original_isinstance = isinstance + with patch("library.strategy_lumina.isinstance") as mock_isinstance: + + def custom_isinstance(obj, class_or_tuple): + if obj is mock_model and class_or_tuple is Gemma2Model: + return True + if hasattr(obj, "_orig_mod") and obj._orig_mod is mock_model and class_or_tuple is Gemma2Model: + return True + return original_isinstance(obj, class_or_tuple) + + mock_isinstance.side_effect = custom_isinstance + + # Prepare sample text + text = "Test encoding strategy" + tokens, attention_mask = tokenize_strategy.tokenize(text) + + # Perform encoding + hidden_states, input_ids, attention_masks = encoding_strategy.encode_tokens( + tokenize_strategy, [mock_model], (tokens, attention_mask) + ) + + # Validate outputs + assert original_isinstance(hidden_states, torch.Tensor) + assert original_isinstance(input_ids, torch.Tensor) + assert original_isinstance(attention_masks, torch.Tensor) + + # Check the shape of the second-to-last hidden state + assert hidden_states.ndim == 3 + + # Test weighted encoding (which falls back to standard encoding for Lumina) + weights = [torch.ones_like(tokens)] + hidden_states_w, input_ids_w, attention_masks_w = encoding_strategy.encode_tokens_with_weights( + tokenize_strategy, [mock_model], (tokens, attention_mask), weights + ) + + # For the mock, we can't guarantee identical outputs since each call returns random tensors + # Instead, check that the outputs have the same shape and are tensors + assert hidden_states_w.shape == hidden_states.shape + assert original_isinstance(hidden_states_w, torch.Tensor) + assert torch.allclose(input_ids, input_ids_w) # Input IDs should be the same + assert torch.allclose(attention_masks, attention_masks_w) # Attention masks should be the same + + +def test_lumina_text_encoder_outputs_caching_strategy(): + # Create a temporary directory for caching + with tempfile.TemporaryDirectory() as tmpdir: + # Create a cache file path + cache_file = os.path.join(tmpdir, "test_outputs.npz") + + # Create the caching strategy + caching_strategy = LuminaTextEncoderOutputsCachingStrategy( + cache_to_disk=True, + batch_size=1, + skip_disk_cache_validity_check=False, + ) + + # Create a mock class for ImageInfo + class MockImageInfo: + def __init__(self, caption, cache_path): + self.caption = caption + self.text_encoder_outputs_npz = cache_path + + # Create a sample input info + image_info = MockImageInfo("Test caption", cache_file) + + # Simulate a batch + batch = [image_info] + + # Create mock strategies and model + try: + tokenize_strategy = LuminaTokenizeStrategy("dummy system prompt", max_length=None) + except OSError as e: + # If the tokenizer is not found (due to gated repo), we can skip the test + print(f"Skipping LuminaTokenizeStrategy test due to OSError: {e}") + return + encoding_strategy = LuminaTextEncodingStrategy() + mock_model = SimpleMockGemma2Model() + + # Patch the isinstance check to accept our simple mock + original_isinstance = isinstance + with patch("library.strategy_lumina.isinstance") as mock_isinstance: + + def custom_isinstance(obj, class_or_tuple): + if obj is mock_model and class_or_tuple is Gemma2Model: + return True + if hasattr(obj, "_orig_mod") and obj._orig_mod is mock_model and class_or_tuple is Gemma2Model: + return True + return original_isinstance(obj, class_or_tuple) + + mock_isinstance.side_effect = custom_isinstance + + # Call cache_batch_outputs + caching_strategy.cache_batch_outputs(tokenize_strategy, [mock_model], encoding_strategy, batch) + + # Verify the npz file was created + assert os.path.exists(cache_file), f"Cache file not created at {cache_file}" + + # Verify the is_disk_cached_outputs_expected method + assert caching_strategy.is_disk_cached_outputs_expected(cache_file) + + # Test loading from npz + loaded_data = caching_strategy.load_outputs_npz(cache_file) + assert len(loaded_data) == 3 # hidden_state, input_ids, attention_mask + + +def test_lumina_latents_caching_strategy(): + # Create a temporary directory for caching + with tempfile.TemporaryDirectory() as tmpdir: + # Prepare a mock absolute path + abs_path = os.path.join(tmpdir, "test_image.png") + + # Use smaller image size for faster testing + image_size = (64, 64) + + # Create a smaller dummy image for testing + test_image = np.random.randint(0, 255, (64, 64, 3), dtype=np.uint8) + + # Create the caching strategy + caching_strategy = LuminaLatentsCachingStrategy(cache_to_disk=True, batch_size=1, skip_disk_cache_validity_check=False) + + # Create a simple mock VAE + class MockVAE: + def __init__(self): + self.device = torch.device("cpu") + self.dtype = torch.float32 + + def encode(self, x): + # Return smaller encoded tensor for faster processing + encoded = torch.randn(1, 4, 8, 8, device=x.device) + return type("EncodedLatents", (), {"to": lambda *args, **kwargs: encoded}) + + # Prepare a mock batch + class MockImageInfo: + def __init__(self, path, image): + self.absolute_path = path + self.image = image + self.image_path = path + self.bucket_reso = image_size + self.resized_size = image_size + self.resize_interpolation = "lanczos" + # Specify full path to the latents npz file + self.latents_npz = os.path.join(tmpdir, f"{os.path.splitext(os.path.basename(path))[0]}_0064x0064_lumina.npz") + + batch = [MockImageInfo(abs_path, test_image)] + + # Call cache_batch_latents + mock_vae = MockVAE() + caching_strategy.cache_batch_latents(mock_vae, batch, flip_aug=False, alpha_mask=False, random_crop=False) + + # Generate the expected npz path + npz_path = caching_strategy.get_latents_npz_path(abs_path, image_size) + + # Verify the file was created + assert os.path.exists(npz_path), f"NPZ file not created at {npz_path}" + + # Verify is_disk_cached_latents_expected + assert caching_strategy.is_disk_cached_latents_expected(image_size, npz_path, False, False) + + # Test loading from disk + loaded_data = caching_strategy.load_latents_from_disk(npz_path, image_size) + assert len(loaded_data) == 5 # Check for 5 expected elements diff --git a/tests/test_custom_offloading_utils.py b/tests/test_custom_offloading_utils.py new file mode 100644 index 000000000..8c23bdf55 --- /dev/null +++ b/tests/test_custom_offloading_utils.py @@ -0,0 +1,408 @@ +import pytest +import torch +import torch.nn as nn +from unittest.mock import patch, MagicMock + +from library.custom_offloading_utils import ( + _synchronize_device, + swap_weight_devices_cuda, + swap_weight_devices_no_cuda, + weighs_to_device, + Offloader, + ModelOffloader +) + +class TransformerBlock(nn.Module): + def __init__(self, block_idx: int): + super().__init__() + self.block_idx = block_idx + self.linear1 = nn.Linear(10, 5) + self.linear2 = nn.Linear(5, 10) + self.seq = nn.Sequential(nn.SiLU(), nn.Linear(10, 10)) + + def forward(self, x): + x = self.linear1(x) + x = torch.relu(x) + x = self.linear2(x) + x = self.seq(x) + return x + + +class SimpleModel(nn.Module): + def __init__(self, num_blocks=16): + super().__init__() + self.blocks = nn.ModuleList([ + TransformerBlock(i) + for i in range(num_blocks)]) + + def forward(self, x): + for block in self.blocks: + x = block(x) + return x + + @property + def device(self): + return next(self.parameters()).device + + +# Device Synchronization Tests +@patch('torch.cuda.synchronize') +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +def test_cuda_synchronize(mock_cuda_sync): + device = torch.device('cuda') + _synchronize_device(device) + mock_cuda_sync.assert_called_once() + +@patch('torch.xpu.synchronize') +@pytest.mark.skipif(not torch.xpu.is_available(), reason="XPU not available") +def test_xpu_synchronize(mock_xpu_sync): + device = torch.device('xpu') + _synchronize_device(device) + mock_xpu_sync.assert_called_once() + +@patch('torch.mps.synchronize') +@pytest.mark.skipif(not torch.xpu.is_available(), reason="MPS not available") +def test_mps_synchronize(mock_mps_sync): + device = torch.device('mps') + _synchronize_device(device) + mock_mps_sync.assert_called_once() + + +# Weights to Device Tests +def test_weights_to_device(): + # Create a simple model with weights + model = nn.Sequential( + nn.Linear(10, 5), + nn.ReLU(), + nn.Linear(5, 2) + ) + + # Start with CPU tensors + device = torch.device('cpu') + for module in model.modules(): + if hasattr(module, "weight") and module.weight is not None: + assert module.weight.device == device + + # Move to mock CUDA device + mock_device = torch.device('cuda') + with patch('torch.Tensor.to', return_value=torch.zeros(1).to(device)): + weighs_to_device(model, mock_device) + + # Since we mocked the to() function, we can only verify modules were processed + # but can't check actual device movement + + +# Swap Weight Devices Tests +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +def test_swap_weight_devices_cuda(): + device = torch.device('cuda') + layer_to_cpu = SimpleModel() + layer_to_cuda = SimpleModel() + + # Move layer to CUDA to move to CPU + layer_to_cpu.to(device) + + with patch('torch.Tensor.to', return_value=torch.zeros(1)): + with patch('torch.Tensor.copy_'): + swap_weight_devices_cuda(device, layer_to_cpu, layer_to_cuda) + + assert layer_to_cpu.device.type == 'cpu' + assert layer_to_cuda.device.type == 'cuda' + + + +@patch('library.custom_offloading_utils._synchronize_device') +def test_swap_weight_devices_no_cuda(mock_sync_device): + device = torch.device('cpu') + layer_to_cpu = SimpleModel() + layer_to_cuda = SimpleModel() + + with patch('torch.Tensor.to', return_value=torch.zeros(1)): + with patch('torch.Tensor.copy_'): + swap_weight_devices_no_cuda(device, layer_to_cpu, layer_to_cuda) + + # Verify _synchronize_device was called twice + assert mock_sync_device.call_count == 2 + + +# Offloader Tests +@pytest.fixture +def offloader(): + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + return Offloader( + num_blocks=4, + blocks_to_swap=2, + device=device, + debug=False + ) + + +def test_offloader_init(offloader): + assert offloader.num_blocks == 4 + assert offloader.blocks_to_swap == 2 + assert hasattr(offloader, 'thread_pool') + assert offloader.futures == {} + assert offloader.cuda_available == (offloader.device.type == 'cuda') + + +@patch('library.custom_offloading_utils.swap_weight_devices_cuda') +@patch('library.custom_offloading_utils.swap_weight_devices_no_cuda') +def test_swap_weight_devices(mock_no_cuda, mock_cuda, offloader: Offloader): + block_to_cpu = SimpleModel() + block_to_cuda = SimpleModel() + + # Force test for CUDA device + offloader.cuda_available = True + offloader.swap_weight_devices(block_to_cpu, block_to_cuda) + mock_cuda.assert_called_once_with(offloader.device, block_to_cpu, block_to_cuda) + mock_no_cuda.assert_not_called() + + # Reset mocks + mock_cuda.reset_mock() + mock_no_cuda.reset_mock() + + # Force test for non-CUDA device + offloader.cuda_available = False + offloader.swap_weight_devices(block_to_cpu, block_to_cuda) + mock_no_cuda.assert_called_once_with(offloader.device, block_to_cpu, block_to_cuda) + mock_cuda.assert_not_called() + + +@patch('library.custom_offloading_utils.Offloader.swap_weight_devices') +def test_submit_move_blocks(mock_swap, offloader): + blocks = [SimpleModel() for _ in range(4)] + block_idx_to_cpu = 0 + block_idx_to_cuda = 2 + + # Mock the thread pool to execute synchronously + future = MagicMock() + future.result.return_value = (block_idx_to_cpu, block_idx_to_cuda) + offloader.thread_pool.submit = MagicMock(return_value=future) + + offloader._submit_move_blocks(blocks, block_idx_to_cpu, block_idx_to_cuda) + + # Check that the future is stored with the correct key + assert block_idx_to_cuda in offloader.futures + + +def test_wait_blocks_move(offloader): + block_idx = 2 + + # Test with no future for the block + offloader._wait_blocks_move(block_idx) # Should not raise + + # Create a fake future and test waiting + future = MagicMock() + future.result.return_value = (0, block_idx) + offloader.futures[block_idx] = future + + offloader._wait_blocks_move(block_idx) + + # Check that the future was removed + assert block_idx not in offloader.futures + future.result.assert_called_once() + + +# ModelOffloader Tests +@pytest.fixture +def model_offloader(): + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + blocks_to_swap = 2 + blocks = SimpleModel(4).blocks + return ModelOffloader( + blocks=blocks, + blocks_to_swap=blocks_to_swap, + device=device, + debug=False + ) + + +def test_model_offloader_init(model_offloader): + assert model_offloader.num_blocks == 4 + assert model_offloader.blocks_to_swap == 2 + assert hasattr(model_offloader, 'thread_pool') + assert model_offloader.futures == {} + assert len(model_offloader.remove_handles) > 0 # Should have registered hooks + + +def test_create_backward_hook(): + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + blocks_to_swap = 2 + blocks = SimpleModel(4).blocks + model_offloader = ModelOffloader( + blocks=blocks, + blocks_to_swap=blocks_to_swap, + device=device, + debug=False + ) + + # Test hook creation for swapping case (block 0) + hook_swap = model_offloader.create_backward_hook(blocks, 0) + assert hook_swap is None + + # Test hook creation for waiting case (block 1) + hook_wait = model_offloader.create_backward_hook(blocks, 1) + assert hook_wait is not None + + # Test hook creation for no action case (block 3) + hook_none = model_offloader.create_backward_hook(blocks, 3) + assert hook_none is None + + +@patch('library.custom_offloading_utils.ModelOffloader._submit_move_blocks') +@patch('library.custom_offloading_utils.ModelOffloader._wait_blocks_move') +def test_backward_hook_execution(mock_wait, mock_submit): + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + blocks_to_swap = 2 + model = SimpleModel(4) + blocks = model.blocks + model_offloader = ModelOffloader( + blocks=blocks, + blocks_to_swap=blocks_to_swap, + device=device, + debug=False + ) + + # Test swapping hook (block 1) + hook_swap = model_offloader.create_backward_hook(blocks, 1) + assert hook_swap is not None + hook_swap(model, torch.zeros(1), torch.zeros(1)) + mock_submit.assert_called_once() + + mock_submit.reset_mock() + + # Test waiting hook (block 2) + hook_wait = model_offloader.create_backward_hook(blocks, 2) + assert hook_wait is not None + hook_wait(model, torch.zeros(1), torch.zeros(1)) + assert mock_wait.call_count == 2 + + +@patch('library.custom_offloading_utils.weighs_to_device') +@patch('library.custom_offloading_utils._synchronize_device') +@patch('library.custom_offloading_utils._clean_memory_on_device') +def test_prepare_block_devices_before_forward(mock_clean, mock_sync, mock_weights_to_device, model_offloader): + model = SimpleModel(4) + blocks = model.blocks + + with patch.object(nn.Module, 'to'): + model_offloader.prepare_block_devices_before_forward(blocks) + + # Check that weighs_to_device was called for each block + assert mock_weights_to_device.call_count == 4 + + # Check that _synchronize_device and _clean_memory_on_device were called + mock_sync.assert_called_once_with(model_offloader.device) + mock_clean.assert_called_once_with(model_offloader.device) + + +@patch('library.custom_offloading_utils.ModelOffloader._wait_blocks_move') +def test_wait_for_block(mock_wait, model_offloader): + # Test with blocks_to_swap=0 + model_offloader.blocks_to_swap = 0 + model_offloader.wait_for_block(1) + mock_wait.assert_not_called() + + # Test with blocks_to_swap=2 + model_offloader.blocks_to_swap = 2 + block_idx = 1 + model_offloader.wait_for_block(block_idx) + mock_wait.assert_called_once_with(block_idx) + + +@patch('library.custom_offloading_utils.ModelOffloader._submit_move_blocks') +def test_submit_move_blocks(mock_submit, model_offloader): + model = SimpleModel() + blocks = model.blocks + + # Test with blocks_to_swap=0 + model_offloader.blocks_to_swap = 0 + model_offloader.submit_move_blocks(blocks, 1) + mock_submit.assert_not_called() + + mock_submit.reset_mock() + model_offloader.blocks_to_swap = 2 + + # Test within swap range + block_idx = 1 + model_offloader.submit_move_blocks(blocks, block_idx) + mock_submit.assert_called_once() + + mock_submit.reset_mock() + + # Test outside swap range + block_idx = 3 + model_offloader.submit_move_blocks(blocks, block_idx) + mock_submit.assert_not_called() + + +# Integration test for offloading in a realistic scenario +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +def test_offloading_integration(): + device = torch.device('cuda') + # Create a mini model with 4 blocks + model = SimpleModel(5) + model.to(device) + blocks = model.blocks + + # Initialize model offloader + offloader = ModelOffloader( + blocks=blocks, + blocks_to_swap=2, + device=device, + debug=True + ) + + # Prepare blocks for forward pass + offloader.prepare_block_devices_before_forward(blocks) + + # Simulate forward pass with offloading + input_tensor = torch.randn(1, 10, device=device) + x = input_tensor + + for i, block in enumerate(blocks): + # Wait for the current block to be ready + offloader.wait_for_block(i) + + # Process through the block + x = block(x) + + # Schedule moving weights for future blocks + offloader.submit_move_blocks(blocks, i) + + # Verify we get a valid output + assert x.shape == (1, 10) + assert not torch.isnan(x).any() + + +# Error handling tests +def test_offloader_assertion_error(): + with pytest.raises(AssertionError): + device = torch.device('cpu') + layer_to_cpu = SimpleModel() + layer_to_cuda = nn.Linear(10, 5) # Different class + swap_weight_devices_cuda(device, layer_to_cpu, layer_to_cuda) + +if __name__ == "__main__": + # Run all tests when file is executed directly + import sys + + # Configure pytest command line arguments + pytest_args = [ + "-v", # Verbose output + "--color=yes", # Colored output + __file__, # Run tests in this file + ] + + # Add optional arguments from command line + if len(sys.argv) > 1: + pytest_args.extend(sys.argv[1:]) + + # Print info about test execution + print(f"Running tests with PyTorch {torch.__version__}") + print(f"CUDA available: {torch.cuda.is_available()}") + if torch.cuda.is_available(): + print(f"CUDA device: {torch.cuda.get_device_name(0)}") + + # Run the tests + sys.exit(pytest.main(pytest_args)) diff --git a/tests/test_fine_tune.py b/tests/test_fine_tune.py new file mode 100644 index 000000000..fd39ce612 --- /dev/null +++ b/tests/test_fine_tune.py @@ -0,0 +1,6 @@ +import fine_tune + + +def test_syntax(): + # Very simply testing that the train_network imports without syntax errors + assert True diff --git a/tests/test_flux_train.py b/tests/test_flux_train.py new file mode 100644 index 000000000..2b8739cfc --- /dev/null +++ b/tests/test_flux_train.py @@ -0,0 +1,6 @@ +import flux_train + + +def test_syntax(): + # Very simply testing that the train_network imports without syntax errors + assert True diff --git a/tests/test_flux_train_network.py b/tests/test_flux_train_network.py new file mode 100644 index 000000000..aaff89624 --- /dev/null +++ b/tests/test_flux_train_network.py @@ -0,0 +1,5 @@ +import flux_train_network + +def test_syntax(): + # Very simply testing that the flux_train_network imports without syntax errors + assert True diff --git a/tests/test_lumina_train_network.py b/tests/test_lumina_train_network.py new file mode 100644 index 000000000..2b8fe21d4 --- /dev/null +++ b/tests/test_lumina_train_network.py @@ -0,0 +1,177 @@ +import pytest +import torch +from unittest.mock import MagicMock, patch +import argparse + +from library import lumina_models, lumina_util +from lumina_train_network import LuminaNetworkTrainer + + +@pytest.fixture +def lumina_trainer(): + return LuminaNetworkTrainer() + + +@pytest.fixture +def mock_args(): + args = MagicMock() + args.pretrained_model_name_or_path = "test_path" + args.disable_mmap_load_safetensors = False + args.use_flash_attn = False + args.use_sage_attn = False + args.fp8_base = False + args.blocks_to_swap = None + args.gemma2 = "test_gemma2_path" + args.ae = "test_ae_path" + args.cache_text_encoder_outputs = True + args.cache_text_encoder_outputs_to_disk = False + args.network_train_unet_only = False + return args + + +@pytest.fixture +def mock_accelerator(): + accelerator = MagicMock() + accelerator.device = torch.device("cpu") + accelerator.prepare.side_effect = lambda x, **kwargs: x + accelerator.unwrap_model.side_effect = lambda x: x + return accelerator + + +def test_assert_extra_args(lumina_trainer, mock_args): + train_dataset_group = MagicMock() + train_dataset_group.verify_bucket_reso_steps = MagicMock() + val_dataset_group = MagicMock() + val_dataset_group.verify_bucket_reso_steps = MagicMock() + + # Test with default settings + lumina_trainer.assert_extra_args(mock_args, train_dataset_group, val_dataset_group) + + # Verify verify_bucket_reso_steps was called for both groups + assert train_dataset_group.verify_bucket_reso_steps.call_count > 0 + assert val_dataset_group.verify_bucket_reso_steps.call_count > 0 + + # Check text encoder output caching + assert lumina_trainer.train_gemma2 is (not mock_args.network_train_unet_only) + assert mock_args.cache_text_encoder_outputs is True + + +def test_load_target_model(lumina_trainer, mock_args, mock_accelerator): + # Patch lumina_util methods + with ( + patch("library.lumina_util.load_lumina_model") as mock_load_lumina_model, + patch("library.lumina_util.load_gemma2") as mock_load_gemma2, + patch("library.lumina_util.load_ae") as mock_load_ae, + ): + # Create mock models + mock_model = MagicMock(spec=lumina_models.NextDiT) + mock_model.dtype = torch.float32 + mock_gemma2 = MagicMock() + mock_ae = MagicMock() + + mock_load_lumina_model.return_value = mock_model + mock_load_gemma2.return_value = mock_gemma2 + mock_load_ae.return_value = mock_ae + + # Test load_target_model + version, gemma2_list, ae, model = lumina_trainer.load_target_model(mock_args, torch.float32, mock_accelerator) + + # Verify calls and return values + assert version == lumina_util.MODEL_VERSION_LUMINA_V2 + assert gemma2_list == [mock_gemma2] + assert ae == mock_ae + assert model == mock_model + + # Verify load calls + mock_load_lumina_model.assert_called_once() + mock_load_gemma2.assert_called_once() + mock_load_ae.assert_called_once() + + +def test_get_strategies(lumina_trainer, mock_args): + # Test tokenize strategy + try: + tokenize_strategy = lumina_trainer.get_tokenize_strategy(mock_args) + assert tokenize_strategy.__class__.__name__ == "LuminaTokenizeStrategy" + except OSError as e: + # If the tokenizer is not found (due to gated repo), we can skip the test + print(f"Skipping LuminaTokenizeStrategy test due to OSError: {e}") + + # Test latents caching strategy + latents_strategy = lumina_trainer.get_latents_caching_strategy(mock_args) + assert latents_strategy.__class__.__name__ == "LuminaLatentsCachingStrategy" + + # Test text encoding strategy + text_encoding_strategy = lumina_trainer.get_text_encoding_strategy(mock_args) + assert text_encoding_strategy.__class__.__name__ == "LuminaTextEncodingStrategy" + + +def test_text_encoder_output_caching_strategy(lumina_trainer, mock_args): + # Call assert_extra_args to set train_gemma2 + train_dataset_group = MagicMock() + train_dataset_group.verify_bucket_reso_steps = MagicMock() + val_dataset_group = MagicMock() + val_dataset_group.verify_bucket_reso_steps = MagicMock() + lumina_trainer.assert_extra_args(mock_args, train_dataset_group, val_dataset_group) + + # With text encoder caching enabled + mock_args.skip_cache_check = False + mock_args.text_encoder_batch_size = 16 + strategy = lumina_trainer.get_text_encoder_outputs_caching_strategy(mock_args) + + assert strategy.__class__.__name__ == "LuminaTextEncoderOutputsCachingStrategy" + assert strategy.cache_to_disk is False # based on mock_args + + # With text encoder caching disabled + mock_args.cache_text_encoder_outputs = False + strategy = lumina_trainer.get_text_encoder_outputs_caching_strategy(mock_args) + assert strategy is None + + +def test_noise_scheduler(lumina_trainer, mock_args): + device = torch.device("cpu") + noise_scheduler = lumina_trainer.get_noise_scheduler(mock_args, device) + + assert noise_scheduler.__class__.__name__ == "FlowMatchEulerDiscreteScheduler" + assert noise_scheduler.num_train_timesteps == 1000 + assert hasattr(lumina_trainer, "noise_scheduler_copy") + + +def test_sai_model_spec(lumina_trainer, mock_args): + with patch("library.train_util.get_sai_model_spec") as mock_get_spec: + mock_get_spec.return_value = "test_spec" + spec = lumina_trainer.get_sai_model_spec(mock_args) + assert spec == "test_spec" + mock_get_spec.assert_called_once_with(None, mock_args, False, True, False, lumina="lumina2") + + +def test_update_metadata(lumina_trainer, mock_args): + metadata = {} + lumina_trainer.update_metadata(metadata, mock_args) + + assert "ss_weighting_scheme" in metadata + assert "ss_logit_mean" in metadata + assert "ss_logit_std" in metadata + assert "ss_mode_scale" in metadata + assert "ss_timestep_sampling" in metadata + assert "ss_sigmoid_scale" in metadata + assert "ss_model_prediction_type" in metadata + assert "ss_discrete_flow_shift" in metadata + + +def test_is_text_encoder_not_needed_for_training(lumina_trainer, mock_args): + # Test with text encoder output caching, but not training text encoder + mock_args.cache_text_encoder_outputs = True + with patch.object(lumina_trainer, "is_train_text_encoder", return_value=False): + result = lumina_trainer.is_text_encoder_not_needed_for_training(mock_args) + assert result is True + + # Test with text encoder output caching and training text encoder + with patch.object(lumina_trainer, "is_train_text_encoder", return_value=True): + result = lumina_trainer.is_text_encoder_not_needed_for_training(mock_args) + assert result is False + + # Test with no text encoder output caching + mock_args.cache_text_encoder_outputs = False + result = lumina_trainer.is_text_encoder_not_needed_for_training(mock_args) + assert result is False diff --git a/tests/test_optimizer.py b/tests/test_optimizer.py new file mode 100644 index 000000000..f6ade91a6 --- /dev/null +++ b/tests/test_optimizer.py @@ -0,0 +1,153 @@ +from unittest.mock import patch +from library.train_util import get_optimizer +from train_network import setup_parser +import torch +from torch.nn import Parameter + +# Optimizer libraries +import bitsandbytes as bnb +from lion_pytorch import lion_pytorch +import schedulefree + +import dadaptation +import dadaptation.experimental as dadapt_experimental + +import prodigyopt +import schedulefree as sf +import transformers + + +def test_default_get_optimizer(): + with patch("sys.argv", [""]): + parser = setup_parser() + args = parser.parse_args() + params_t = torch.tensor([1.5, 1.5]) + + param = Parameter(params_t) + optimizer_name, optimizer_args, optimizer = get_optimizer(args, [param]) + assert optimizer_name == "torch.optim.adamw.AdamW" + assert optimizer_args == "" + assert isinstance(optimizer, torch.optim.AdamW) + + +def test_get_schedulefree_optimizer(): + with patch("sys.argv", ["", "--optimizer_type", "AdamWScheduleFree"]): + parser = setup_parser() + args = parser.parse_args() + params_t = torch.tensor([1.5, 1.5]) + + param = Parameter(params_t) + optimizer_name, optimizer_args, optimizer = get_optimizer(args, [param]) + assert optimizer_name == "schedulefree.adamw_schedulefree.AdamWScheduleFree" + assert optimizer_args == "" + assert isinstance(optimizer, schedulefree.adamw_schedulefree.AdamWScheduleFree) + + +def test_all_supported_optimizers(): + optimizers = [ + { + "name": "bitsandbytes.optim.adamw.AdamW8bit", + "alias": "AdamW8bit", + "instance": bnb.optim.AdamW8bit, + }, + { + "name": "lion_pytorch.lion_pytorch.Lion", + "alias": "Lion", + "instance": lion_pytorch.Lion, + }, + { + "name": "torch.optim.adamw.AdamW", + "alias": "AdamW", + "instance": torch.optim.AdamW, + }, + { + "name": "bitsandbytes.optim.lion.Lion8bit", + "alias": "Lion8bit", + "instance": bnb.optim.Lion8bit, + }, + { + "name": "bitsandbytes.optim.adamw.PagedAdamW8bit", + "alias": "PagedAdamW8bit", + "instance": bnb.optim.PagedAdamW8bit, + }, + { + "name": "bitsandbytes.optim.lion.PagedLion8bit", + "alias": "PagedLion8bit", + "instance": bnb.optim.PagedLion8bit, + }, + { + "name": "bitsandbytes.optim.adamw.PagedAdamW", + "alias": "PagedAdamW", + "instance": bnb.optim.PagedAdamW, + }, + { + "name": "bitsandbytes.optim.adamw.PagedAdamW32bit", + "alias": "PagedAdamW32bit", + "instance": bnb.optim.PagedAdamW32bit, + }, + {"name": "torch.optim.sgd.SGD", "alias": "SGD", "instance": torch.optim.SGD}, + { + "name": "dadaptation.experimental.dadapt_adam_preprint.DAdaptAdamPreprint", + "alias": "DAdaptAdamPreprint", + "instance": dadapt_experimental.DAdaptAdamPreprint, + }, + { + "name": "dadaptation.dadapt_adagrad.DAdaptAdaGrad", + "alias": "DAdaptAdaGrad", + "instance": dadaptation.DAdaptAdaGrad, + }, + { + "name": "dadaptation.dadapt_adan.DAdaptAdan", + "alias": "DAdaptAdan", + "instance": dadaptation.DAdaptAdan, + }, + { + "name": "dadaptation.experimental.dadapt_adan_ip.DAdaptAdanIP", + "alias": "DAdaptAdanIP", + "instance": dadapt_experimental.DAdaptAdanIP, + }, + { + "name": "dadaptation.dadapt_lion.DAdaptLion", + "alias": "DAdaptLion", + "instance": dadaptation.DAdaptLion, + }, + { + "name": "dadaptation.dadapt_sgd.DAdaptSGD", + "alias": "DAdaptSGD", + "instance": dadaptation.DAdaptSGD, + }, + { + "name": "prodigyopt.prodigy.Prodigy", + "alias": "Prodigy", + "instance": prodigyopt.Prodigy, + }, + { + "name": "transformers.optimization.Adafactor", + "alias": "Adafactor", + "instance": transformers.optimization.Adafactor, + }, + { + "name": "schedulefree.adamw_schedulefree.AdamWScheduleFree", + "alias": "AdamWScheduleFree", + "instance": sf.AdamWScheduleFree, + }, + { + "name": "schedulefree.sgd_schedulefree.SGDScheduleFree", + "alias": "SGDScheduleFree", + "instance": sf.SGDScheduleFree, + }, + ] + + for opt in optimizers: + with patch("sys.argv", ["", "--optimizer_type", opt.get("alias")]): + parser = setup_parser() + args = parser.parse_args() + params_t = torch.tensor([1.5, 1.5]) + + param = Parameter(params_t) + optimizer_name, _, optimizer = get_optimizer(args, [param]) + assert optimizer_name == opt.get("name") + + instance = opt.get("instance") + assert instance is not None + assert isinstance(optimizer, instance) diff --git a/tests/test_sd3_train.py b/tests/test_sd3_train.py new file mode 100644 index 000000000..a7c5d27a2 --- /dev/null +++ b/tests/test_sd3_train.py @@ -0,0 +1,6 @@ +import sd3_train + + +def test_syntax(): + # Very simply testing that the train_network imports without syntax errors + assert True diff --git a/tests/test_sd3_train_network.py b/tests/test_sd3_train_network.py new file mode 100644 index 000000000..10c0795cb --- /dev/null +++ b/tests/test_sd3_train_network.py @@ -0,0 +1,5 @@ +import sd3_train_network + +def test_syntax(): + # Very simply testing that the flux_train_network imports without syntax errors + assert True diff --git a/tests/test_sdxl_train.py b/tests/test_sdxl_train.py new file mode 100644 index 000000000..1c0e85799 --- /dev/null +++ b/tests/test_sdxl_train.py @@ -0,0 +1,6 @@ +import sdxl_train + + +def test_syntax(): + # Very simply testing that the train_network imports without syntax errors + assert True diff --git a/tests/test_sdxl_train_network.py b/tests/test_sdxl_train_network.py new file mode 100644 index 000000000..58300ae7d --- /dev/null +++ b/tests/test_sdxl_train_network.py @@ -0,0 +1,6 @@ +import sdxl_train_network + + +def test_syntax(): + # Very simply testing that the train_network imports without syntax errors + assert True diff --git a/tests/test_train.py b/tests/test_train.py new file mode 100644 index 000000000..51c794924 --- /dev/null +++ b/tests/test_train.py @@ -0,0 +1,6 @@ +import train_db + + +def test_syntax(): + # Very simply testing that the train_network imports without syntax errors + assert True diff --git a/tests/test_train_network.py b/tests/test_train_network.py new file mode 100644 index 000000000..fe17263c6 --- /dev/null +++ b/tests/test_train_network.py @@ -0,0 +1,5 @@ +import train_network + +def test_syntax(): + # Very simply testing that the train_network imports without syntax errors + assert True diff --git a/tests/test_train_textual_inversion.py b/tests/test_train_textual_inversion.py new file mode 100644 index 000000000..ab6a93425 --- /dev/null +++ b/tests/test_train_textual_inversion.py @@ -0,0 +1,5 @@ +import train_textual_inversion + +def test_syntax(): + # Very simply testing that the train_network imports without syntax errors + assert True diff --git a/tests/test_validation.py b/tests/test_validation.py new file mode 100644 index 000000000..f80686d8c --- /dev/null +++ b/tests/test_validation.py @@ -0,0 +1,17 @@ +from library.train_util import split_train_val + + +def test_split_train_val(): + paths = ["path1", "path2", "path3", "path4", "path5", "path6", "path7"] + sizes = [(1, 1), (2, 2), None, (4, 4), (5, 5), (6, 6), None] + result_paths, result_sizes = split_train_val(paths, sizes, True, 0.2, 1234) + assert result_paths == ["path2", "path3", "path6", "path5", "path1", "path4"], result_paths + assert result_sizes == [(2, 2), None, (6, 6), (5, 5), (1, 1), (4, 4)], result_sizes + + result_paths, result_sizes = split_train_val(paths, sizes, False, 0.2, 1234) + assert result_paths == ["path7"], result_paths + assert result_sizes == [None], result_sizes + + +if __name__ == "__main__": + test_split_train_val() diff --git a/tools/cache_latents.py b/tools/cache_latents.py index 2f0098b42..5baddb5bf 100644 --- a/tools/cache_latents.py +++ b/tools/cache_latents.py @@ -9,50 +9,83 @@ import torch from tqdm import tqdm -from library import config_util +from library import config_util, flux_train_utils, flux_utils, strategy_base, strategy_flux, strategy_sd, strategy_sdxl from library import train_util from library import sdxl_train_util +import library.sai_model_spec as sai_model_spec from library.config_util import ( ConfigSanitizer, BlueprintGenerator, ) from library.utils import setup_logging, add_logging_arguments + setup_logging() import logging logger = logging.getLogger(__name__) +def set_tokenize_strategy(is_sd: bool, is_sdxl: bool, is_flux: bool, args: argparse.Namespace) -> None: + if is_flux: + _, is_schnell, _ = flux_utils.check_flux_state_dict_diffusers_schnell(args.pretrained_model_name_or_path) + else: + is_schnell = False + + if is_sd: + tokenize_strategy = strategy_sd.SdTokenizeStrategy(args.v2, args.max_token_length, args.tokenizer_cache_dir) + elif is_sdxl: + tokenize_strategy = strategy_sdxl.SdxlTokenizeStrategy(args.max_token_length, args.tokenizer_cache_dir) + else: + if args.t5xxl_max_token_length is None: + if is_schnell: + t5xxl_max_token_length = 256 + else: + t5xxl_max_token_length = 512 + else: + t5xxl_max_token_length = args.t5xxl_max_token_length + + logger.info(f"t5xxl_max_token_length: {t5xxl_max_token_length}") + tokenize_strategy = strategy_flux.FluxTokenizeStrategy(t5xxl_max_token_length, args.tokenizer_cache_dir) + strategy_base.TokenizeStrategy.set_strategy(tokenize_strategy) + + def cache_to_disk(args: argparse.Namespace) -> None: setup_logging(args, reset=True) train_util.prepare_dataset_args(args, True) + train_util.enable_high_vram(args) - # check cache latents arg - assert args.cache_latents_to_disk, "cache_latents_to_disk must be True / cache_latents_to_diskはTrueである必要があります" + # assert args.cache_latents_to_disk, "cache_latents_to_disk must be True / cache_latents_to_diskはTrueである必要があります" + args.cache_latents = True + args.cache_latents_to_disk = True use_dreambooth_method = args.in_json is None if args.seed is not None: set_seed(args.seed) # 乱数系列を初期化する - # tokenizerを準備する:datasetを動かすために必要 - if args.sdxl: - tokenizer1, tokenizer2 = sdxl_train_util.load_tokenizers(args) - tokenizers = [tokenizer1, tokenizer2] + is_sd = not args.sdxl and not args.flux + is_sdxl = args.sdxl + is_flux = args.flux + + set_tokenize_strategy(is_sd, is_sdxl, is_flux, args) + + if is_sd or is_sdxl: + latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy(is_sd, True, args.vae_batch_size, args.skip_cache_check) else: - tokenizer = train_util.load_tokenizer(args) - tokenizers = [tokenizer] + latents_caching_strategy = strategy_flux.FluxLatentsCachingStrategy(True, args.vae_batch_size, args.skip_cache_check) + strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy) # データセットを準備する + use_user_config = args.dataset_config is not None if args.dataset_class is None: - blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, True)) - if args.dataset_config is not None: - logger.info(f"Load dataset config from {args.dataset_config}") + blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, args.masked_loss, True)) + if use_user_config: + logger.info(f"Loading dataset config from {args.dataset_config}") user_config = config_util.load_user_config(args.dataset_config) - ignored = ["train_data_dir", "in_json"] + ignored = ["train_data_dir", "reg_data_dir", "in_json"] if any(getattr(args, attr) is not None for attr in ignored): logger.warning( - "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( + "ignoring the following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( ", ".join(ignored) ) ) @@ -83,17 +116,12 @@ def cache_to_disk(args: argparse.Namespace) -> None: ] } - blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizers) - train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) + blueprint = blueprint_generator.generate(user_config, args) + train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) else: - train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizers) - - # datasetのcache_latentsを呼ばなければ、生の画像が返る - - current_epoch = Value("i", 0) - current_step = Value("i", 0) - ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None - collator = train_util.collator_class(current_epoch, current_step, ds_for_collator) + # use arbitrary dataset class + train_dataset_group = train_util.load_arbitrary_dataset(args) + val_dataset_group = None # acceleratorを準備する logger.info("prepare accelerator") @@ -106,72 +134,27 @@ def cache_to_disk(args: argparse.Namespace) -> None: # モデルを読み込む logger.info("load model") - if args.sdxl: + if is_sd: + _, vae, _, _ = train_util.load_target_model(args, weight_dtype, accelerator) + elif is_sdxl: (_, _, _, vae, _, _, _) = sdxl_train_util.load_target_model(args, accelerator, "sdxl", weight_dtype) else: - _, vae, _, _ = train_util.load_target_model(args, weight_dtype, accelerator) + vae = flux_utils.load_ae(args.ae, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors) + + if is_sd or is_sdxl: + if torch.__version__ >= "2.0.0": # PyTorch 2.0.0 以上対応のxformersなら以下が使える + vae.set_use_memory_efficient_attention_xformers(args.xformers) - if torch.__version__ >= "2.0.0": # PyTorch 2.0.0 以上対応のxformersなら以下が使える - vae.set_use_memory_efficient_attention_xformers(args.xformers) vae.to(accelerator.device, dtype=vae_dtype) vae.requires_grad_(False) vae.eval() - # dataloaderを準備する - train_dataset_group.set_caching_mode("latents") - - # DataLoaderのプロセス数:0 は persistent_workers が使えないので注意 - n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers - - train_dataloader = torch.utils.data.DataLoader( - train_dataset_group, - batch_size=1, - shuffle=True, - collate_fn=collator, - num_workers=n_workers, - persistent_workers=args.persistent_data_loader_workers, - ) - - # acceleratorを使ってモデルを準備する:マルチGPUで使えるようになるはず - train_dataloader = accelerator.prepare(train_dataloader) - - # データ取得のためのループ - for batch in tqdm(train_dataloader): - b_size = len(batch["images"]) - vae_batch_size = b_size if args.vae_batch_size is None else args.vae_batch_size - flip_aug = batch["flip_aug"] - alpha_mask = batch["alpha_mask"] - random_crop = batch["random_crop"] - bucket_reso = batch["bucket_reso"] - - # バッチを分割して処理する - for i in range(0, b_size, vae_batch_size): - images = batch["images"][i : i + vae_batch_size] - absolute_paths = batch["absolute_paths"][i : i + vae_batch_size] - resized_sizes = batch["resized_sizes"][i : i + vae_batch_size] - - image_infos = [] - for i, (image, absolute_path, resized_size) in enumerate(zip(images, absolute_paths, resized_sizes)): - image_info = train_util.ImageInfo(absolute_path, 1, "dummy", False, absolute_path) - image_info.image = image - image_info.bucket_reso = bucket_reso - image_info.resized_size = resized_size - image_info.latents_npz = os.path.splitext(absolute_path)[0] + ".npz" - - if args.skip_existing: - if train_util.is_disk_cached_latents_is_expected( - image_info.bucket_reso, image_info.latents_npz, flip_aug, alpha_mask - ): - logger.warning(f"Skipping {image_info.latents_npz} because it already exists.") - continue - - image_infos.append(image_info) - - if len(image_infos) > 0: - train_util.cache_batch_latents(vae, True, image_infos, flip_aug, alpha_mask, random_crop) + # cache latents with dataset + # TODO use DataLoader to speed up + train_dataset_group.new_cache_latents(vae, accelerator) accelerator.wait_for_everyone() - accelerator.print(f"Finished caching latents for {len(train_dataset_group)} batches.") + accelerator.print(f"Finished caching latents to disk.") def setup_parser() -> argparse.ArgumentParser: @@ -179,10 +162,16 @@ def setup_parser() -> argparse.ArgumentParser: add_logging_arguments(parser) train_util.add_sd_models_arguments(parser) + sai_model_spec.add_model_spec_arguments(parser) train_util.add_training_arguments(parser, True) train_util.add_dataset_arguments(parser, True, True, True) + train_util.add_masked_loss_arguments(parser) config_util.add_config_arguments(parser) + train_util.add_dit_training_arguments(parser) + flux_train_utils.add_flux_train_arguments(parser) + parser.add_argument("--sdxl", action="store_true", help="Use SDXL model / SDXLモデルを使用する") + parser.add_argument("--flux", action="store_true", help="Use FLUX model / FLUXモデルを使用する") parser.add_argument( "--no_half_vae", action="store_true", @@ -191,7 +180,8 @@ def setup_parser() -> argparse.ArgumentParser: parser.add_argument( "--skip_existing", action="store_true", - help="skip images if npz already exists (both normal and flipped exists if flip_aug is enabled) / npzが既に存在する画像をスキップする(flip_aug有効時は通常、反転の両方が存在する画像をスキップ)", + help="[Deprecated] This option does not work. Existing .npz files are always checked. Use `--skip_cache_check` to skip the check." + " / [非推奨] このオプションは機能しません。既存の .npz は常に検証されます。`--skip_cache_check` で検証をスキップできます。", ) return parser diff --git a/tools/cache_text_encoder_outputs.py b/tools/cache_text_encoder_outputs.py index a75d9da74..8e6042923 100644 --- a/tools/cache_text_encoder_outputs.py +++ b/tools/cache_text_encoder_outputs.py @@ -9,55 +9,70 @@ import torch from tqdm import tqdm -from library import config_util +from library import ( + config_util, + flux_train_utils, + flux_utils, + sdxl_model_util, + strategy_base, + strategy_flux, + strategy_sd, + strategy_sdxl, +) from library import train_util from library import sdxl_train_util +from library import utils +import library.sai_model_spec as sai_model_spec from library.config_util import ( ConfigSanitizer, BlueprintGenerator, ) from library.utils import setup_logging, add_logging_arguments +from cache_latents import set_tokenize_strategy + setup_logging() import logging + logger = logging.getLogger(__name__) + def cache_to_disk(args: argparse.Namespace) -> None: setup_logging(args, reset=True) train_util.prepare_dataset_args(args, True) + train_util.enable_high_vram(args) - # check cache arg - assert ( - args.cache_text_encoder_outputs_to_disk - ), "cache_text_encoder_outputs_to_disk must be True / cache_text_encoder_outputs_to_diskはTrueである必要があります" - - # できるだけ準備はしておくが今のところSDXLのみしか動かない - assert ( - args.sdxl - ), "cache_text_encoder_outputs_to_disk is only available for SDXL / cache_text_encoder_outputs_to_diskはSDXLのみ利用可能です" + args.cache_text_encoder_outputs = True + args.cache_text_encoder_outputs_to_disk = True use_dreambooth_method = args.in_json is None if args.seed is not None: set_seed(args.seed) # 乱数系列を初期化する - # tokenizerを準備する:datasetを動かすために必要 - if args.sdxl: - tokenizer1, tokenizer2 = sdxl_train_util.load_tokenizers(args) - tokenizers = [tokenizer1, tokenizer2] - else: - tokenizer = train_util.load_tokenizer(args) - tokenizers = [tokenizer] + is_sd = not args.sdxl and not args.flux + is_sdxl = args.sdxl + is_flux = args.flux + + assert ( + is_sdxl or is_flux + ), "Cache text encoder outputs to disk is only supported for SDXL and FLUX models / テキストエンコーダ出力のディスクキャッシュはSDXLまたはFLUXでのみ有効です" + assert ( + is_sdxl or args.weighted_captions is None + ), "Weighted captions are only supported for SDXL models / 重み付きキャプションはSDXLモデルでのみ有効です" + + set_tokenize_strategy(is_sd, is_sdxl, is_flux, args) # データセットを準備する + use_user_config = args.dataset_config is not None if args.dataset_class is None: - blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, True)) - if args.dataset_config is not None: - logger.info(f"Load dataset config from {args.dataset_config}") + blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, args.masked_loss, True)) + if use_user_config: + logger.info(f"Loading dataset config from {args.dataset_config}") user_config = config_util.load_user_config(args.dataset_config) - ignored = ["train_data_dir", "in_json"] + ignored = ["train_data_dir", "reg_data_dir", "in_json"] if any(getattr(args, attr) is not None for attr in ignored): logger.warning( - "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( + "ignoring the following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( ", ".join(ignored) ) ) @@ -88,15 +103,12 @@ def cache_to_disk(args: argparse.Namespace) -> None: ] } - blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizers) - train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) + blueprint = blueprint_generator.generate(user_config, args) + train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) else: - train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizers) - - current_epoch = Value("i", 0) - current_step = Value("i", 0) - ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None - collator = train_util.collator_class(current_epoch, current_step, ds_for_collator) + # use arbitrary dataset class + train_dataset_group = train_util.load_arbitrary_dataset(args) + val_dataset_group = None # acceleratorを準備する logger.info("prepare accelerator") @@ -105,69 +117,71 @@ def cache_to_disk(args: argparse.Namespace) -> None: # mixed precisionに対応した型を用意しておき適宜castする weight_dtype, _ = train_util.prepare_dtype(args) + t5xxl_dtype = utils.str_to_dtype(args.t5xxl_dtype, weight_dtype) # モデルを読み込む logger.info("load model") - if args.sdxl: - (_, text_encoder1, text_encoder2, _, _, _, _) = sdxl_train_util.load_target_model(args, accelerator, "sdxl", weight_dtype) + if is_sdxl: + _, text_encoder1, text_encoder2, _, _, _, _ = sdxl_train_util.load_target_model( + args, accelerator, sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, weight_dtype + ) + text_encoder1.to(accelerator.device, weight_dtype) + text_encoder2.to(accelerator.device, weight_dtype) text_encoders = [text_encoder1, text_encoder2] else: - text_encoder1, _, _, _ = train_util.load_target_model(args, weight_dtype, accelerator) - text_encoders = [text_encoder1] + clip_l = flux_utils.load_clip_l( + args.clip_l, weight_dtype, accelerator.device, disable_mmap=args.disable_mmap_load_safetensors + ) + + t5xxl = flux_utils.load_t5xxl(args.t5xxl, None, accelerator.device, disable_mmap=args.disable_mmap_load_safetensors) + + if t5xxl.dtype == torch.float8_e4m3fnuz or t5xxl.dtype == torch.float8_e5m2 or t5xxl.dtype == torch.float8_e5m2fnuz: + raise ValueError(f"Unsupported fp8 model dtype: {t5xxl.dtype}") + elif t5xxl.dtype == torch.float8_e4m3fn: + logger.info("Loaded fp8 T5XXL model") + + if t5xxl_dtype != t5xxl_dtype: + if t5xxl.dtype == torch.float8_e4m3fn and t5xxl_dtype.itemsize() >= 2: + logger.warning( + "The loaded model is fp8, but the specified T5XXL dtype is larger than fp8. This may cause a performance drop." + " / ロードされたモデルはfp8ですが、指定されたT5XXLのdtypeがfp8より高精度です。精度低下が発生する可能性があります。" + ) + logger.info(f"Casting T5XXL model to {t5xxl_dtype}") + t5xxl.to(t5xxl_dtype) + + text_encoders = [clip_l, t5xxl] for text_encoder in text_encoders: - text_encoder.to(accelerator.device, dtype=weight_dtype) text_encoder.requires_grad_(False) text_encoder.eval() - # dataloaderを準備する - train_dataset_group.set_caching_mode("text") - - # DataLoaderのプロセス数:0 は persistent_workers が使えないので注意 - n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers - - train_dataloader = torch.utils.data.DataLoader( - train_dataset_group, - batch_size=1, - shuffle=True, - collate_fn=collator, - num_workers=n_workers, - persistent_workers=args.persistent_data_loader_workers, - ) + # build text encoder outputs caching strategy + if is_sdxl: + text_encoder_outputs_caching_strategy = strategy_sdxl.SdxlTextEncoderOutputsCachingStrategy( + args.cache_text_encoder_outputs_to_disk, None, args.skip_cache_check, is_weighted=args.weighted_captions + ) + else: + text_encoder_outputs_caching_strategy = strategy_flux.FluxTextEncoderOutputsCachingStrategy( + args.cache_text_encoder_outputs_to_disk, + args.text_encoder_batch_size, + args.skip_cache_check, + is_partial=False, + apply_t5_attn_mask=args.apply_t5_attn_mask, + ) + strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(text_encoder_outputs_caching_strategy) + + # build text encoding strategy + if is_sdxl: + text_encoding_strategy = strategy_sdxl.SdxlTextEncodingStrategy() + else: + text_encoding_strategy = strategy_flux.FluxTextEncodingStrategy(args.apply_t5_attn_mask) + strategy_base.TextEncodingStrategy.set_strategy(text_encoding_strategy) - # acceleratorを使ってモデルを準備する:マルチGPUで使えるようになるはず - train_dataloader = accelerator.prepare(train_dataloader) - - # データ取得のためのループ - for batch in tqdm(train_dataloader): - absolute_paths = batch["absolute_paths"] - input_ids1_list = batch["input_ids1_list"] - input_ids2_list = batch["input_ids2_list"] - - image_infos = [] - for absolute_path, input_ids1, input_ids2 in zip(absolute_paths, input_ids1_list, input_ids2_list): - image_info = train_util.ImageInfo(absolute_path, 1, "dummy", False, absolute_path) - image_info.text_encoder_outputs_npz = os.path.splitext(absolute_path)[0] + train_util.TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX - image_info - - if args.skip_existing: - if os.path.exists(image_info.text_encoder_outputs_npz): - logger.warning(f"Skipping {image_info.text_encoder_outputs_npz} because it already exists.") - continue - - image_info.input_ids1 = input_ids1 - image_info.input_ids2 = input_ids2 - image_infos.append(image_info) - - if len(image_infos) > 0: - b_input_ids1 = torch.stack([image_info.input_ids1 for image_info in image_infos]) - b_input_ids2 = torch.stack([image_info.input_ids2 for image_info in image_infos]) - train_util.cache_batch_text_encoder_outputs( - image_infos, tokenizers, text_encoders, args.max_token_length, True, b_input_ids1, b_input_ids2, weight_dtype - ) + # cache text encoder outputs + train_dataset_group.new_cache_text_encoder_outputs(text_encoders, accelerator) accelerator.wait_for_everyone() - accelerator.print(f"Finished caching latents for {len(train_dataset_group)} batches.") + accelerator.print(f"Finished caching text encoder outputs to disk.") def setup_parser() -> argparse.ArgumentParser: @@ -175,15 +189,33 @@ def setup_parser() -> argparse.ArgumentParser: add_logging_arguments(parser) train_util.add_sd_models_arguments(parser) + sai_model_spec.add_model_spec_arguments(parser) train_util.add_training_arguments(parser, True) train_util.add_dataset_arguments(parser, True, True, True) + train_util.add_masked_loss_arguments(parser) config_util.add_config_arguments(parser) - sdxl_train_util.add_sdxl_training_arguments(parser) + train_util.add_dit_training_arguments(parser) + flux_train_utils.add_flux_train_arguments(parser) + parser.add_argument("--sdxl", action="store_true", help="Use SDXL model / SDXLモデルを使用する") + parser.add_argument("--flux", action="store_true", help="Use FLUX model / FLUXモデルを使用する") + parser.add_argument( + "--t5xxl_dtype", + type=str, + default=None, + help="T5XXL model dtype, default: None (use mixed precision dtype) / T5XXLモデルのdtype, デフォルト: None (mixed precisionのdtypeを使用)", + ) parser.add_argument( "--skip_existing", action="store_true", - help="skip images if npz already exists (both normal and flipped exists if flip_aug is enabled) / npzが既に存在する画像をスキップする(flip_aug有効時は通常、反転の両方が存在する画像をスキップ)", + help="[Deprecated] This option does not work. Existing .npz files are always checked. Use `--skip_cache_check` to skip the check." + " / [非推奨] このオプションは機能しません。既存の .npz は常に検証されます。`--skip_cache_check` で検証をスキップできます。", + ) + parser.add_argument( + "--weighted_captions", + action="store_true", + default=False, + help="Enable weighted captions in the standard style (token:1.3). No commas inside parens, or shuffle/dropout may break the decoder. / 「[token]」、「(token)」「(token:1.3)」のような重み付きキャプションを有効にする。カンマを括弧内に入れるとシャッフルやdropoutで重みづけがおかしくなるので注意", ) return parser diff --git a/tools/convert_diffusers_to_flux.py b/tools/convert_diffusers_to_flux.py new file mode 100644 index 000000000..9dcd8fed6 --- /dev/null +++ b/tools/convert_diffusers_to_flux.py @@ -0,0 +1,150 @@ +# This script converts the diffusers of a Flux model to a safetensors file of a Flux.1 model. +# It is based on the implementation by 2kpr. Thanks to 2kpr! +# Major changes: +# - Iterates over three safetensors files to reduce memory usage, not loading all tensors at once. +# - Makes reverse map from diffusers map to avoid loading all tensors. +# - Removes dependency on .json file for weights mapping. +# - Adds support for custom memory efficient load and save functions. +# - Supports saving with different precision. +# - Supports .safetensors file as input. + +# Copyright 2024 2kpr. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and + +import argparse +import os +from pathlib import Path +import safetensors +from safetensors.torch import safe_open +import torch +from tqdm import tqdm + +from library import flux_utils +from library.utils import setup_logging, str_to_dtype +from library.safetensors_utils import MemoryEfficientSafeOpen, mem_eff_save_file + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + + +def convert(args): + # if diffusers_path is folder, get safetensors file + diffusers_path = Path(args.diffusers_path) + if diffusers_path.is_dir(): + diffusers_path = Path.joinpath(diffusers_path, "transformer", "diffusion_pytorch_model-00001-of-00003.safetensors") + + flux_path = Path(args.save_to) + if not os.path.exists(flux_path.parent): + os.makedirs(flux_path.parent) + + if not diffusers_path.exists(): + logger.error(f"Error: Missing transformer safetensors file: {diffusers_path}") + return + + mem_eff_flag = args.mem_eff_load_save + save_dtype = str_to_dtype(args.save_precision) if args.save_precision is not None else None + + # make reverse map from diffusers map + diffusers_to_bfl_map = flux_utils.make_diffusers_to_bfl_map(19, 38) + + # iterate over three safetensors files to reduce memory usage + flux_sd = {} + for i in range(3): + # replace 00001 with 0000i + current_diffusers_path = Path(str(diffusers_path).replace("00001", f"0000{i+1}")) + logger.info(f"Loading diffusers file: {current_diffusers_path}") + + open_func = MemoryEfficientSafeOpen if mem_eff_flag else (lambda x: safe_open(x, framework="pt")) + with open_func(current_diffusers_path) as f: + for diffusers_key in tqdm(f.keys()): + if diffusers_key in diffusers_to_bfl_map: + tensor = f.get_tensor(diffusers_key).to("cpu") + if save_dtype is not None: + tensor = tensor.to(save_dtype) + + index, bfl_key = diffusers_to_bfl_map[diffusers_key] + if bfl_key not in flux_sd: + flux_sd[bfl_key] = [] + flux_sd[bfl_key].append((index, tensor)) + else: + logger.error(f"Error: Key not found in diffusers_to_bfl_map: {diffusers_key}") + return + + # concat tensors if multiple tensors are mapped to a single key, sort by index + for key, values in flux_sd.items(): + if len(values) == 1: + flux_sd[key] = values[0][1] + else: + flux_sd[key] = torch.cat([value[1] for value in sorted(values, key=lambda x: x[0])]) + + # special case for final_layer.adaLN_modulation.1.weight and final_layer.adaLN_modulation.1.bias + def swap_scale_shift(weight): + shift, scale = weight.chunk(2, dim=0) + new_weight = torch.cat([scale, shift], dim=0) + return new_weight + + if "final_layer.adaLN_modulation.1.weight" in flux_sd: + flux_sd["final_layer.adaLN_modulation.1.weight"] = swap_scale_shift(flux_sd["final_layer.adaLN_modulation.1.weight"]) + if "final_layer.adaLN_modulation.1.bias" in flux_sd: + flux_sd["final_layer.adaLN_modulation.1.bias"] = swap_scale_shift(flux_sd["final_layer.adaLN_modulation.1.bias"]) + + # save flux_sd to safetensors file + logger.info(f"Saving Flux safetensors file: {flux_path}") + if mem_eff_flag: + mem_eff_save_file(flux_sd, flux_path) + else: + safetensors.torch.save_file(flux_sd, flux_path) + + logger.info("Conversion completed.") + + +def setup_parser(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--diffusers_path", + default=None, + type=str, + required=True, + help="Path to the original Flux diffusers folder or *-00001-of-00003.safetensors file." + " / 元のFlux diffusersフォルダーまたは*-00001-of-00003.safetensorsファイルへのパス", + ) + parser.add_argument( + "--save_to", + default=None, + type=str, + required=True, + help="Output path for the Flux safetensors file. / Flux safetensorsファイルの出力先", + ) + parser.add_argument( + "--mem_eff_load_save", + action="store_true", + help="use custom memory efficient load and save functions for FLUX.1 model" + " / カスタムのメモリ効率の良い読み込みと保存関数をFLUX.1モデルに使用する", + ) + parser.add_argument( + "--save_precision", + type=str, + default=None, + help="precision in saving, default is same as loading precision" + "float32, fp16, bf16, fp8 (same as fp8_e4m3fn), fp8_e4m3fn, fp8_e4m3fnuz, fp8_e5m2, fp8_e5m2fnuz" + " / 保存時に精度を変更して保存する、デフォルトは読み込み時と同じ精度", + ) + return parser + + +if __name__ == "__main__": + parser = setup_parser() + args = parser.parse_args() + convert(args) diff --git a/tools/detect_face_rotate.py b/tools/detect_face_rotate.py index d2a4d9cfb..16fd7d0b7 100644 --- a/tools/detect_face_rotate.py +++ b/tools/detect_face_rotate.py @@ -15,7 +15,7 @@ from anime_face_detector import create_detector from tqdm import tqdm import numpy as np -from library.utils import setup_logging, pil_resize +from library.utils import setup_logging, resize_image setup_logging() import logging logger = logging.getLogger(__name__) @@ -170,12 +170,9 @@ def process(args): scale = max(cur_crop_width / w, cur_crop_height / h) if scale != 1.0: - w = int(w * scale + .5) - h = int(h * scale + .5) - if scale < 1.0: - face_img = cv2.resize(face_img, (w, h), interpolation=cv2.INTER_AREA) - else: - face_img = pil_resize(face_img, (w, h)) + rw = int(w * scale + .5) + rh = int(h * scale + .5) + face_img = resize_image(face_img, w, h, rw, rh) cx = int(cx * scale + .5) cy = int(cy * scale + .5) fw = int(fw * scale + .5) diff --git a/tools/merge_sd3_safetensors.py b/tools/merge_sd3_safetensors.py new file mode 100644 index 000000000..6ec045ddc --- /dev/null +++ b/tools/merge_sd3_safetensors.py @@ -0,0 +1,167 @@ +import argparse +import os +import gc +from typing import Dict, Optional, Union +import torch +from safetensors.torch import safe_open + +from library.utils import setup_logging +from library.utils import str_to_dtype +from library.safetensors_utils import load_safetensors, mem_eff_save_file + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + + +def merge_safetensors( + dit_path: str, + vae_path: Optional[str] = None, + clip_l_path: Optional[str] = None, + clip_g_path: Optional[str] = None, + t5xxl_path: Optional[str] = None, + output_path: str = "merged_model.safetensors", + device: str = "cpu", + save_precision: Optional[str] = None, +): + """ + Merge multiple safetensors files into a single file + + Args: + dit_path: Path to the DiT/MMDiT model + vae_path: Path to the VAE model + clip_l_path: Path to the CLIP-L model + clip_g_path: Path to the CLIP-G model + t5xxl_path: Path to the T5-XXL model + output_path: Path to save the merged model + device: Device to load tensors to + save_precision: Target dtype for model weights (e.g. 'fp16', 'bf16') + """ + logger.info("Starting to merge safetensors files...") + + # Convert save_precision string to torch dtype if specified + if save_precision: + target_dtype = str_to_dtype(save_precision) + else: + target_dtype = None + + # 1. Get DiT metadata if available + metadata = None + try: + with safe_open(dit_path, framework="pt") as f: + metadata = f.metadata() # may be None + if metadata: + logger.info(f"Found metadata in DiT model: {metadata}") + except Exception as e: + logger.warning(f"Failed to read metadata from DiT model: {e}") + + # 2. Create empty merged state dict + merged_state_dict = {} + + # 3. Load and merge each model with memory management + + # DiT/MMDiT - prefix: model.diffusion_model. + # This state dict may have VAE keys. + logger.info(f"Loading DiT model from {dit_path}") + dit_state_dict = load_safetensors(dit_path, device=device, disable_mmap=True, dtype=target_dtype) + logger.info(f"Adding DiT model with {len(dit_state_dict)} keys") + for key, value in dit_state_dict.items(): + if key.startswith("model.diffusion_model.") or key.startswith("first_stage_model."): + merged_state_dict[key] = value + else: + merged_state_dict[f"model.diffusion_model.{key}"] = value + # Free memory + del dit_state_dict + gc.collect() + + # VAE - prefix: first_stage_model. + # May be omitted if VAE is already included in DiT model. + if vae_path: + logger.info(f"Loading VAE model from {vae_path}") + vae_state_dict = load_safetensors(vae_path, device=device, disable_mmap=True, dtype=target_dtype) + logger.info(f"Adding VAE model with {len(vae_state_dict)} keys") + for key, value in vae_state_dict.items(): + if key.startswith("first_stage_model."): + merged_state_dict[key] = value + else: + merged_state_dict[f"first_stage_model.{key}"] = value + # Free memory + del vae_state_dict + gc.collect() + + # CLIP-L - prefix: text_encoders.clip_l. + if clip_l_path: + logger.info(f"Loading CLIP-L model from {clip_l_path}") + clip_l_state_dict = load_safetensors(clip_l_path, device=device, disable_mmap=True, dtype=target_dtype) + logger.info(f"Adding CLIP-L model with {len(clip_l_state_dict)} keys") + for key, value in clip_l_state_dict.items(): + if key.startswith("text_encoders.clip_l.transformer."): + merged_state_dict[key] = value + else: + merged_state_dict[f"text_encoders.clip_l.transformer.{key}"] = value + # Free memory + del clip_l_state_dict + gc.collect() + + # CLIP-G - prefix: text_encoders.clip_g. + if clip_g_path: + logger.info(f"Loading CLIP-G model from {clip_g_path}") + clip_g_state_dict = load_safetensors(clip_g_path, device=device, disable_mmap=True, dtype=target_dtype) + logger.info(f"Adding CLIP-G model with {len(clip_g_state_dict)} keys") + for key, value in clip_g_state_dict.items(): + if key.startswith("text_encoders.clip_g.transformer."): + merged_state_dict[key] = value + else: + merged_state_dict[f"text_encoders.clip_g.transformer.{key}"] = value + # Free memory + del clip_g_state_dict + gc.collect() + + # T5-XXL - prefix: text_encoders.t5xxl. + if t5xxl_path: + logger.info(f"Loading T5-XXL model from {t5xxl_path}") + t5xxl_state_dict = load_safetensors(t5xxl_path, device=device, disable_mmap=True, dtype=target_dtype) + logger.info(f"Adding T5-XXL model with {len(t5xxl_state_dict)} keys") + for key, value in t5xxl_state_dict.items(): + if key.startswith("text_encoders.t5xxl.transformer."): + merged_state_dict[key] = value + else: + merged_state_dict[f"text_encoders.t5xxl.transformer.{key}"] = value + # Free memory + del t5xxl_state_dict + gc.collect() + + # 4. Save merged state dict + logger.info(f"Saving merged model to {output_path} with {len(merged_state_dict)} keys total") + mem_eff_save_file(merged_state_dict, output_path, metadata) + logger.info("Successfully merged safetensors files") + + +def main(): + parser = argparse.ArgumentParser(description="Merge Stable Diffusion 3.5 model components into a single safetensors file") + parser.add_argument("--dit", required=True, help="Path to the DiT/MMDiT model") + parser.add_argument("--vae", help="Path to the VAE model. May be omitted if VAE is included in DiT model") + parser.add_argument("--clip_l", help="Path to the CLIP-L model") + parser.add_argument("--clip_g", help="Path to the CLIP-G model") + parser.add_argument("--t5xxl", help="Path to the T5-XXL model") + parser.add_argument("--output", default="merged_model.safetensors", help="Path to save the merged model") + parser.add_argument("--device", default="cpu", help="Device to load tensors to") + parser.add_argument("--save_precision", type=str, help="Precision to save the model in (e.g., 'fp16', 'bf16', 'float16', etc.)") + + args = parser.parse_args() + + merge_safetensors( + dit_path=args.dit, + vae_path=args.vae, + clip_l_path=args.clip_l, + clip_g_path=args.clip_g, + t5xxl_path=args.t5xxl, + output_path=args.output, + device=args.device, + save_precision=args.save_precision, + ) + + +if __name__ == "__main__": + main() diff --git a/tools/resize_images_to_resolution.py b/tools/resize_images_to_resolution.py index 0f9e00b1e..f5fbae2bb 100644 --- a/tools/resize_images_to_resolution.py +++ b/tools/resize_images_to_resolution.py @@ -6,7 +6,7 @@ import math from PIL import Image import numpy as np -from library.utils import setup_logging, pil_resize +from library.utils import setup_logging, resize_image setup_logging() import logging logger = logging.getLogger(__name__) @@ -22,14 +22,6 @@ def resize_images(src_img_folder, dst_img_folder, max_resolution="512x512", divi if not os.path.exists(dst_img_folder): os.makedirs(dst_img_folder) - # Select interpolation method - if interpolation == 'lanczos4': - pil_interpolation = Image.LANCZOS - elif interpolation == 'cubic': - pil_interpolation = Image.BICUBIC - else: - cv2_interpolation = cv2.INTER_AREA - # Iterate through all files in src_img_folder img_exts = (".png", ".jpg", ".jpeg", ".webp", ".bmp") # copy from train_util.py for filename in os.listdir(src_img_folder): @@ -63,11 +55,7 @@ def resize_images(src_img_folder, dst_img_folder, max_resolution="512x512", divi new_height = int(img.shape[0] * math.sqrt(scale_factor)) new_width = int(img.shape[1] * math.sqrt(scale_factor)) - # Resize image - if cv2_interpolation: - img = cv2.resize(img, (new_width, new_height), interpolation=cv2_interpolation) - else: - img = pil_resize(img, (new_width, new_height), interpolation=pil_interpolation) + img = resize_image(img, img.shape[0], img.shape[1], new_height, new_width, interpolation) else: new_height, new_width = img.shape[0:2] @@ -113,8 +101,8 @@ def setup_parser() -> argparse.ArgumentParser: help='Maximum resolution(s) in the format "512x512,384x384, etc, etc" / 最大画像サイズをカンマ区切りで指定 ("512x512,384x384, etc, etc" など)', default="512x512,384x384,256x256,128x128") parser.add_argument('--divisible_by', type=int, help='Ensure new dimensions are divisible by this value / リサイズ後の画像のサイズをこの値で割り切れるようにします', default=1) - parser.add_argument('--interpolation', type=str, choices=['area', 'cubic', 'lanczos4'], - default='area', help='Interpolation method for resizing / リサイズ時の補完方法') + parser.add_argument('--interpolation', type=str, choices=['area', 'cubic', 'lanczos4', 'nearest', 'linear', 'box'], + default=None, help='Interpolation method for resizing. Default to area if smaller, lanczos if larger / サイズ変更の補間方法。小さい場合はデフォルトでエリア、大きい場合はランチョスになります。') parser.add_argument('--save_as_png', action='store_true', help='Save as png format / png形式で保存') parser.add_argument('--copy_associated_files', action='store_true', help='Copy files with same base name to images (captions etc) / 画像と同じファイル名(拡張子を除く)のファイルもコピーする') diff --git a/train_control_net.py b/train_control_net.py new file mode 100644 index 000000000..c12693baf --- /dev/null +++ b/train_control_net.py @@ -0,0 +1,673 @@ +import argparse +import json +import math +import os +import random +import time +from multiprocessing import Value + +# from omegaconf import OmegaConf +import toml + +from tqdm import tqdm + +import torch +from library import deepspeed_utils, strategy_base, strategy_sd +from library.device_utils import init_ipex, clean_memory_on_device + +init_ipex() + +from torch.nn.parallel import DistributedDataParallel as DDP +from accelerate.utils import set_seed +from diffusers import DDPMScheduler, ControlNetModel +from safetensors.torch import load_file + +import library.model_util as model_util +import library.train_util as train_util +import library.config_util as config_util +import library.sai_model_spec as sai_model_spec +from library.config_util import ( + ConfigSanitizer, + BlueprintGenerator, +) +import library.huggingface_util as huggingface_util +import library.custom_train_functions as custom_train_functions +from library.custom_train_functions import ( + apply_snr_weight, + pyramid_noise_like, + apply_noise_offset, +) +from library.utils import setup_logging, add_logging_arguments + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + + +# TODO 他のスクリプトと共通化する +def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_scheduler): + logs = { + "loss/current": current_loss, + "loss/average": avr_loss, + "lr": lr_scheduler.get_last_lr()[0], + } + + if args.optimizer_type.lower().startswith("DAdapt".lower()): + logs["lr/d*lr"] = lr_scheduler.optimizers[-1].param_groups[0]["d"] * lr_scheduler.optimizers[-1].param_groups[0]["lr"] + + return logs + + +def train(args): + # session_id = random.randint(0, 2**32) + # training_started_at = time.time() + train_util.verify_training_args(args) + train_util.prepare_dataset_args(args, True) + setup_logging(args, reset=True) + + cache_latents = args.cache_latents + use_user_config = args.dataset_config is not None + + if args.seed is None: + args.seed = random.randint(0, 2**32) + set_seed(args.seed) + + tokenize_strategy = strategy_sd.SdTokenizeStrategy(args.v2, args.max_token_length, args.tokenizer_cache_dir) + strategy_base.TokenizeStrategy.set_strategy(tokenize_strategy) + tokenizer = tokenize_strategy.tokenizer + # prepare caching strategy: this must be set before preparing dataset. because dataset may use this strategy for initialization. + latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy( + True, args.cache_latents_to_disk, args.vae_batch_size, False + ) + strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy) + + # データセットを準備する + blueprint_generator = BlueprintGenerator(ConfigSanitizer(False, False, True, True)) + if use_user_config: + logger.info(f"Load dataset config from {args.dataset_config}") + user_config = config_util.load_user_config(args.dataset_config) + ignored = ["train_data_dir", "conditioning_data_dir"] + if any(getattr(args, attr) is not None for attr in ignored): + logger.warning( + "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( + ", ".join(ignored) + ) + ) + else: + user_config = { + "datasets": [ + { + "subsets": config_util.generate_controlnet_subsets_config_by_subdirs( + args.train_data_dir, + args.conditioning_data_dir, + args.caption_extension, + ) + } + ] + } + + blueprint = blueprint_generator.generate(user_config, args) + train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) + + current_epoch = Value("i", 0) + current_step = Value("i", 0) + ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None + collator = train_util.collator_class(current_epoch, current_step, ds_for_collator) + + train_dataset_group.verify_bucket_reso_steps(64) + + if args.debug_dataset: + train_util.debug_dataset(train_dataset_group) + return + if len(train_dataset_group) == 0: + logger.error( + "No data found. Please verify arguments (train_data_dir must be the parent of folders with images) / 画像がありません。引数指定を確認してください(train_data_dirには画像があるフォルダではなく、画像があるフォルダの親フォルダを指定する必要があります)" + ) + return + + if cache_latents: + assert ( + train_dataset_group.is_latent_cacheable() + ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません" + + # acceleratorを準備する + logger.info("prepare accelerator") + accelerator = train_util.prepare_accelerator(args) + is_main_process = accelerator.is_main_process + + # mixed precisionに対応した型を用意しておき適宜castする + weight_dtype, save_dtype = train_util.prepare_dtype(args) + + # モデルを読み込む + text_encoder, vae, unet, _ = train_util.load_target_model( + args, weight_dtype, accelerator, unet_use_linear_projection_in_v2=True + ) + + # DiffusersのControlNetが使用するデータを準備する + if args.v2: + unet.config = { + "act_fn": "silu", + "attention_head_dim": [5, 10, 20, 20], + "block_out_channels": [320, 640, 1280, 1280], + "center_input_sample": False, + "cross_attention_dim": 1024, + "down_block_types": ["CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D"], + "downsample_padding": 1, + "dual_cross_attention": False, + "flip_sin_to_cos": True, + "freq_shift": 0, + "in_channels": 4, + "layers_per_block": 2, + "mid_block_scale_factor": 1, + "mid_block_type": "UNetMidBlock2DCrossAttn", + "norm_eps": 1e-05, + "norm_num_groups": 32, + "num_attention_heads": [5, 10, 20, 20], + "num_class_embeds": None, + "only_cross_attention": False, + "out_channels": 4, + "sample_size": 96, + "up_block_types": ["UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"], + "use_linear_projection": True, + "upcast_attention": True, + "only_cross_attention": False, + "downsample_padding": 1, + "use_linear_projection": True, + "class_embed_type": None, + "num_class_embeds": None, + "resnet_time_scale_shift": "default", + "projection_class_embeddings_input_dim": None, + } + else: + unet.config = { + "act_fn": "silu", + "attention_head_dim": 8, + "block_out_channels": [320, 640, 1280, 1280], + "center_input_sample": False, + "cross_attention_dim": 768, + "down_block_types": ["CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D"], + "downsample_padding": 1, + "flip_sin_to_cos": True, + "freq_shift": 0, + "in_channels": 4, + "layers_per_block": 2, + "mid_block_scale_factor": 1, + "mid_block_type": "UNetMidBlock2DCrossAttn", + "norm_eps": 1e-05, + "norm_num_groups": 32, + "num_attention_heads": 8, + "out_channels": 4, + "sample_size": 64, + "up_block_types": ["UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"], + "only_cross_attention": False, + "downsample_padding": 1, + "use_linear_projection": False, + "class_embed_type": None, + "num_class_embeds": None, + "upcast_attention": False, + "resnet_time_scale_shift": "default", + "projection_class_embeddings_input_dim": None, + } + # unet.config = OmegaConf.create(unet.config) + + # make unet.config iterable and accessible by attribute + class CustomConfig: + def __init__(self, **kwargs): + self.__dict__.update(kwargs) + + def __getattr__(self, name): + if name in self.__dict__: + return self.__dict__[name] + else: + raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{name}'") + + def __contains__(self, name): + return name in self.__dict__ + + unet.config = CustomConfig(**unet.config) + + controlnet = ControlNetModel.from_unet(unet) + + if args.controlnet_model_name_or_path: + filename = args.controlnet_model_name_or_path + if os.path.isfile(filename): + if os.path.splitext(filename)[1] == ".safetensors": + state_dict = load_file(filename) + else: + state_dict = torch.load(filename) + state_dict = model_util.convert_controlnet_state_dict_to_diffusers(state_dict) + controlnet.load_state_dict(state_dict) + elif os.path.isdir(filename): + controlnet = ControlNetModel.from_pretrained(filename) + + # モデルに xformers とか memory efficient attention を組み込む + train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa) + + # 学習を準備する + if cache_latents: + vae.to(accelerator.device, dtype=weight_dtype) + vae.requires_grad_(False) + vae.eval() + with torch.no_grad(): + train_dataset_group.new_cache_latents(vae, accelerator) + vae.to("cpu") + clean_memory_on_device(accelerator.device) + + accelerator.wait_for_everyone() + + if args.gradient_checkpointing: + unet.enable_gradient_checkpointing() + controlnet.enable_gradient_checkpointing() + + # 学習に必要なクラスを準備する + accelerator.print("prepare optimizer, data loader etc.") + + trainable_params = list(controlnet.parameters()) + + _, _, optimizer = train_util.get_optimizer(args, trainable_params) + + # dataloaderを準備する + # DataLoaderのプロセス数:0 は persistent_workers が使えないので注意 + train_dataset_group.set_current_strategies() + n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers + + train_dataloader = torch.utils.data.DataLoader( + train_dataset_group, + batch_size=1, + shuffle=True, + collate_fn=collator, + num_workers=n_workers, + persistent_workers=args.persistent_data_loader_workers, + ) + + # 学習ステップ数を計算する + if args.max_train_epochs is not None: + args.max_train_steps = args.max_train_epochs * math.ceil( + len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps + ) + accelerator.print( + f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}" + ) + + # データセット側にも学習ステップを送信 + train_dataset_group.set_max_train_steps(args.max_train_steps) + + # lr schedulerを用意する + lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) + + # 実験的機能:勾配も含めたfp16学習を行う モデル全体をfp16にする + if args.full_fp16: + assert ( + args.mixed_precision == "fp16" + ), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。" + accelerator.print("enable full fp16 training.") + controlnet.to(weight_dtype) + + # acceleratorがなんかよろしくやってくれるらしい + controlnet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + controlnet, optimizer, train_dataloader, lr_scheduler + ) + + if args.fused_backward_pass: + import library.adafactor_fused + + library.adafactor_fused.patch_adafactor_fused(optimizer) + for param_group in optimizer.param_groups: + for parameter in param_group["params"]: + if parameter.requires_grad: + + def __grad_hook(tensor: torch.Tensor, param_group=param_group): + if accelerator.sync_gradients and args.max_grad_norm != 0.0: + accelerator.clip_grad_norm_(tensor, args.max_grad_norm) + optimizer.step_param(tensor, param_group) + tensor.grad = None + + parameter.register_post_accumulate_grad_hook(__grad_hook) + + unet.requires_grad_(False) + text_encoder.requires_grad_(False) + unet.to(accelerator.device) + text_encoder.to(accelerator.device) + + # transform DDP after prepare + controlnet = controlnet.module if isinstance(controlnet, DDP) else controlnet + + controlnet.train() + + if not cache_latents: + vae.requires_grad_(False) + vae.eval() + vae.to(accelerator.device, dtype=weight_dtype) + + # 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする + if args.full_fp16: + train_util.patch_accelerator_for_fp16_training(accelerator) + + # resumeする + train_util.resume_from_local_or_hf_if_specified(accelerator, args) + + # epoch数を計算する + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0): + args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1 + + # 学習する + # TODO: find a way to handle total batch size when there are multiple datasets + accelerator.print("running training / 学習開始") + accelerator.print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}") + accelerator.print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}") + accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}") + accelerator.print(f" num epochs / epoch数: {num_train_epochs}") + accelerator.print( + f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}" + ) + # logger.info(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}") + accelerator.print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}") + accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}") + + progress_bar = tqdm( + range(args.max_train_steps), + smoothing=0, + disable=not accelerator.is_local_main_process, + desc="steps", + ) + global_step = 0 + + noise_scheduler = DDPMScheduler( + beta_start=0.00085, + beta_end=0.012, + beta_schedule="scaled_linear", + num_train_timesteps=1000, + clip_sample=False, + ) + if accelerator.is_main_process: + init_kwargs = {} + if args.wandb_run_name: + init_kwargs["wandb"] = {"name": args.wandb_run_name} + if args.log_tracker_config is not None: + init_kwargs = toml.load(args.log_tracker_config) + accelerator.init_trackers( + "controlnet_train" if args.log_tracker_name is None else args.log_tracker_name, + config=train_util.get_sanitized_config_or_none(args), + init_kwargs=init_kwargs, + ) + + loss_recorder = train_util.LossRecorder() + del train_dataset_group + + # function for saving/removing + def save_model(ckpt_name, model, force_sync_upload=False): + os.makedirs(args.output_dir, exist_ok=True) + ckpt_file = os.path.join(args.output_dir, ckpt_name) + + accelerator.print(f"\nsaving checkpoint: {ckpt_file}") + + state_dict = model_util.convert_controlnet_state_dict_to_sd(model.state_dict()) + + if save_dtype is not None: + for key in list(state_dict.keys()): + v = state_dict[key] + v = v.detach().clone().to("cpu").to(save_dtype) + state_dict[key] = v + + if os.path.splitext(ckpt_file)[1] == ".safetensors": + from safetensors.torch import save_file + + save_file(state_dict, ckpt_file) + else: + torch.save(state_dict, ckpt_file) + + if args.huggingface_repo_id is not None: + huggingface_util.upload(args, ckpt_file, "/" + ckpt_name, force_sync_upload=force_sync_upload) + + def remove_model(old_ckpt_name): + old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name) + if os.path.exists(old_ckpt_file): + accelerator.print(f"removing old checkpoint: {old_ckpt_file}") + os.remove(old_ckpt_file) + + # For --sample_at_first + train_util.sample_images( + accelerator, args, 0, global_step, accelerator.device, vae, tokenizer, text_encoder, unet, controlnet=controlnet + ) + if len(accelerator.trackers) > 0: + # log empty object to commit the sample images to wandb + accelerator.log({}, step=0) + + # training loop + for epoch in range(num_train_epochs): + if is_main_process: + accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}") + current_epoch.value = epoch + 1 + + for step, batch in enumerate(train_dataloader): + current_step.value = global_step + with accelerator.accumulate(controlnet): + with torch.no_grad(): + if "latents" in batch and batch["latents"] is not None: + latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype) + else: + # latentに変換 + latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample() + latents = latents * 0.18215 + b_size = latents.shape[0] + + input_ids = batch["input_ids_list"][0].to(accelerator.device) + encoder_hidden_states = train_util.get_hidden_states(args, input_ids, tokenizer, text_encoder, weight_dtype) + + # Sample noise that we'll add to the latents + noise = torch.randn_like(latents, device=latents.device) + if args.noise_offset: + noise = apply_noise_offset(latents, noise, args.noise_offset, args.adaptive_noise_scale) + elif args.multires_noise_iterations: + noise = pyramid_noise_like( + noise, + latents.device, + args.multires_noise_iterations, + args.multires_noise_discount, + ) + + # Sample a random timestep for each image + timesteps = train_util.get_timesteps(0, noise_scheduler.config.num_train_timesteps, b_size, latents.device) + + # Add noise to the latents according to the noise magnitude at each timestep + # (this is the forward diffusion process) + noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) + + controlnet_image = batch["conditioning_images"].to(dtype=weight_dtype) + + with accelerator.autocast(): + down_block_res_samples, mid_block_res_sample = controlnet( + noisy_latents, + timesteps, + encoder_hidden_states=encoder_hidden_states, + controlnet_cond=controlnet_image, + return_dict=False, + ) + + # Predict the noise residual + noise_pred = unet( + noisy_latents, + timesteps, + encoder_hidden_states, + down_block_additional_residuals=[sample.to(dtype=weight_dtype) for sample in down_block_res_samples], + mid_block_additional_residual=mid_block_res_sample.to(dtype=weight_dtype), + ).sample + + if args.v_parameterization: + # v-parameterization training + target = noise_scheduler.get_velocity(latents, noise, timesteps) + else: + target = noise + + huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler) + loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c) + loss = loss.mean([1, 2, 3]) + + loss_weights = batch["loss_weights"] # 各sampleごとのweight + loss = loss * loss_weights + + if args.min_snr_gamma: + loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization) + + loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし + + accelerator.backward(loss) + if not args.fused_backward_pass: + if accelerator.sync_gradients and args.max_grad_norm != 0.0: + params_to_clip = controlnet.parameters() + accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) + + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad(set_to_none=True) + else: + # optimizer.step() and optimizer.zero_grad() are called in the optimizer hook + lr_scheduler.step() + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + global_step += 1 + + train_util.sample_images( + accelerator, + args, + None, + global_step, + accelerator.device, + vae, + tokenizer, + text_encoder, + unet, + controlnet=controlnet, + ) + + # 指定ステップごとにモデルを保存 + if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0: + accelerator.wait_for_everyone() + if accelerator.is_main_process: + ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, global_step) + save_model( + ckpt_name, + accelerator.unwrap_model(controlnet), + ) + + if args.save_state: + train_util.save_and_remove_state_stepwise(args, accelerator, global_step) + + remove_step_no = train_util.get_remove_step_no(args, global_step) + if remove_step_no is not None: + remove_ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, remove_step_no) + remove_model(remove_ckpt_name) + + current_loss = loss.detach().item() + loss_recorder.add(epoch=epoch, step=step, loss=current_loss) + avr_loss: float = loss_recorder.moving_average + logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]} + progress_bar.set_postfix(**logs) + + if len(accelerator.trackers) > 0: + logs = generate_step_logs(args, current_loss, avr_loss, lr_scheduler) + accelerator.log(logs, step=global_step) + + if global_step >= args.max_train_steps: + break + + if len(accelerator.trackers) > 0: + logs = {"loss/epoch": loss_recorder.moving_average} + accelerator.log(logs, step=epoch + 1) + + accelerator.wait_for_everyone() + + # 指定エポックごとにモデルを保存 + if args.save_every_n_epochs is not None: + saving = (epoch + 1) % args.save_every_n_epochs == 0 and (epoch + 1) < num_train_epochs + if is_main_process and saving: + ckpt_name = train_util.get_epoch_ckpt_name(args, "." + args.save_model_as, epoch + 1) + save_model(ckpt_name, accelerator.unwrap_model(controlnet)) + + remove_epoch_no = train_util.get_remove_epoch_no(args, epoch + 1) + if remove_epoch_no is not None: + remove_ckpt_name = train_util.get_epoch_ckpt_name(args, "." + args.save_model_as, remove_epoch_no) + remove_model(remove_ckpt_name) + + if args.save_state: + train_util.save_and_remove_state_on_epoch_end(args, accelerator, epoch + 1) + + train_util.sample_images( + accelerator, + args, + epoch + 1, + global_step, + accelerator.device, + vae, + tokenizer, + text_encoder, + unet, + controlnet=controlnet, + ) + + # end of epoch + if is_main_process: + controlnet = accelerator.unwrap_model(controlnet) + + accelerator.end_training() + + if is_main_process and (args.save_state or args.save_state_on_train_end): + train_util.save_state_on_train_end(args, accelerator) + + # del accelerator # この後メモリを使うのでこれは消す→printで使うので消さずにおく + + if is_main_process: + ckpt_name = train_util.get_last_ckpt_name(args, "." + args.save_model_as) + save_model(ckpt_name, controlnet, force_sync_upload=True) + + logger.info("model saved.") + + +def setup_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser() + + add_logging_arguments(parser) + train_util.add_sd_models_arguments(parser) + train_util.add_dataset_arguments(parser, False, True, True) + train_util.add_training_arguments(parser, False) + deepspeed_utils.add_deepspeed_arguments(parser) + train_util.add_optimizer_arguments(parser) + config_util.add_config_arguments(parser) + custom_train_functions.add_custom_train_arguments(parser) + + parser.add_argument( + "--save_model_as", + type=str, + default="safetensors", + choices=[None, "ckpt", "pt", "safetensors"], + help="format to save the model (default is .safetensors) / モデル保存時の形式(デフォルトはsafetensors)", + ) + parser.add_argument( + "--controlnet_model_name_or_path", + type=str, + default=None, + help="controlnet model name or path / controlnetのモデル名またはパス", + ) + parser.add_argument( + "--conditioning_data_dir", + type=str, + default=None, + help="conditioning data directory / 条件付けデータのディレクトリ", + ) + + return parser + + +if __name__ == "__main__": + parser = setup_parser() + + args = parser.parse_args() + train_util.verify_command_line_training_args(args) + args = train_util.read_config_from_file(args, parser) + + train(args) diff --git a/train_controlnet.py b/train_controlnet.py index 6938c4bcc..365e35c8c 100644 --- a/train_controlnet.py +++ b/train_controlnet.py @@ -1,42 +1,4 @@ -import argparse -import json -import math -import os -import random -import time -from multiprocessing import Value - -# from omegaconf import OmegaConf -import toml - -from tqdm import tqdm - -import torch -from library import deepspeed_utils -from library.device_utils import init_ipex, clean_memory_on_device - -init_ipex() - -from torch.nn.parallel import DistributedDataParallel as DDP -from accelerate.utils import set_seed -from diffusers import DDPMScheduler, ControlNetModel -from safetensors.torch import load_file - -import library.model_util as model_util -import library.train_util as train_util -import library.config_util as config_util -from library.config_util import ( - ConfigSanitizer, - BlueprintGenerator, -) -import library.huggingface_util as huggingface_util -import library.custom_train_functions as custom_train_functions -from library.custom_train_functions import ( - apply_snr_weight, - pyramid_noise_like, - apply_noise_offset, -) -from library.utils import setup_logging, add_logging_arguments +from library.utils import setup_logging setup_logging() import logging @@ -44,601 +6,14 @@ logger = logging.getLogger(__name__) -# TODO 他のスクリプトと共通化する -def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_scheduler): - logs = { - "loss/current": current_loss, - "loss/average": avr_loss, - "lr": lr_scheduler.get_last_lr()[0], - } - - if args.optimizer_type.lower().startswith("DAdapt".lower()): - logs["lr/d*lr"] = lr_scheduler.optimizers[-1].param_groups[0]["d"] * lr_scheduler.optimizers[-1].param_groups[0]["lr"] - - return logs - - -def train(args): - # session_id = random.randint(0, 2**32) - # training_started_at = time.time() - train_util.verify_training_args(args) - train_util.prepare_dataset_args(args, True) - setup_logging(args, reset=True) - - cache_latents = args.cache_latents - use_user_config = args.dataset_config is not None - - if args.seed is None: - args.seed = random.randint(0, 2**32) - set_seed(args.seed) - - tokenizer = train_util.load_tokenizer(args) - - # データセットを準備する - blueprint_generator = BlueprintGenerator(ConfigSanitizer(False, False, True, True)) - if use_user_config: - logger.info(f"Load dataset config from {args.dataset_config}") - user_config = config_util.load_user_config(args.dataset_config) - ignored = ["train_data_dir", "conditioning_data_dir"] - if any(getattr(args, attr) is not None for attr in ignored): - logger.warning( - "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( - ", ".join(ignored) - ) - ) - else: - user_config = { - "datasets": [ - { - "subsets": config_util.generate_controlnet_subsets_config_by_subdirs( - args.train_data_dir, - args.conditioning_data_dir, - args.caption_extension, - ) - } - ] - } - - blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer) - train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) - - current_epoch = Value("i", 0) - current_step = Value("i", 0) - ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None - collator = train_util.collator_class(current_epoch, current_step, ds_for_collator) - - train_dataset_group.verify_bucket_reso_steps(64) - - if args.debug_dataset: - train_util.debug_dataset(train_dataset_group) - return - if len(train_dataset_group) == 0: - logger.error( - "No data found. Please verify arguments (train_data_dir must be the parent of folders with images) / 画像がありません。引数指定を確認してください(train_data_dirには画像があるフォルダではなく、画像があるフォルダの親フォルダを指定する必要があります)" - ) - return - - if cache_latents: - assert ( - train_dataset_group.is_latent_cacheable() - ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません" - - # acceleratorを準備する - logger.info("prepare accelerator") - accelerator = train_util.prepare_accelerator(args) - is_main_process = accelerator.is_main_process - - # mixed precisionに対応した型を用意しておき適宜castする - weight_dtype, save_dtype = train_util.prepare_dtype(args) - - # モデルを読み込む - text_encoder, vae, unet, _ = train_util.load_target_model( - args, weight_dtype, accelerator, unet_use_linear_projection_in_v2=True - ) - - # DiffusersのControlNetが使用するデータを準備する - if args.v2: - unet.config = { - "act_fn": "silu", - "attention_head_dim": [5, 10, 20, 20], - "block_out_channels": [320, 640, 1280, 1280], - "center_input_sample": False, - "cross_attention_dim": 1024, - "down_block_types": ["CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D"], - "downsample_padding": 1, - "dual_cross_attention": False, - "flip_sin_to_cos": True, - "freq_shift": 0, - "in_channels": 4, - "layers_per_block": 2, - "mid_block_scale_factor": 1, - "mid_block_type": "UNetMidBlock2DCrossAttn", - "norm_eps": 1e-05, - "norm_num_groups": 32, - "num_attention_heads": [5, 10, 20, 20], - "num_class_embeds": None, - "only_cross_attention": False, - "out_channels": 4, - "sample_size": 96, - "up_block_types": ["UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"], - "use_linear_projection": True, - "upcast_attention": True, - "only_cross_attention": False, - "downsample_padding": 1, - "use_linear_projection": True, - "class_embed_type": None, - "num_class_embeds": None, - "resnet_time_scale_shift": "default", - "projection_class_embeddings_input_dim": None, - } - else: - unet.config = { - "act_fn": "silu", - "attention_head_dim": 8, - "block_out_channels": [320, 640, 1280, 1280], - "center_input_sample": False, - "cross_attention_dim": 768, - "down_block_types": ["CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D"], - "downsample_padding": 1, - "flip_sin_to_cos": True, - "freq_shift": 0, - "in_channels": 4, - "layers_per_block": 2, - "mid_block_scale_factor": 1, - "mid_block_type": "UNetMidBlock2DCrossAttn", - "norm_eps": 1e-05, - "norm_num_groups": 32, - "num_attention_heads": 8, - "out_channels": 4, - "sample_size": 64, - "up_block_types": ["UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"], - "only_cross_attention": False, - "downsample_padding": 1, - "use_linear_projection": False, - "class_embed_type": None, - "num_class_embeds": None, - "upcast_attention": False, - "resnet_time_scale_shift": "default", - "projection_class_embeddings_input_dim": None, - } - # unet.config = OmegaConf.create(unet.config) - - # make unet.config iterable and accessible by attribute - class CustomConfig: - def __init__(self, **kwargs): - self.__dict__.update(kwargs) - - def __getattr__(self, name): - if name in self.__dict__: - return self.__dict__[name] - else: - raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{name}'") - - def __contains__(self, name): - return name in self.__dict__ - - unet.config = CustomConfig(**unet.config) - - controlnet = ControlNetModel.from_unet(unet) - - if args.controlnet_model_name_or_path: - filename = args.controlnet_model_name_or_path - if os.path.isfile(filename): - if os.path.splitext(filename)[1] == ".safetensors": - state_dict = load_file(filename) - else: - state_dict = torch.load(filename) - state_dict = model_util.convert_controlnet_state_dict_to_diffusers(state_dict) - controlnet.load_state_dict(state_dict) - elif os.path.isdir(filename): - controlnet = ControlNetModel.from_pretrained(filename) - - # モデルに xformers とか memory efficient attention を組み込む - train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa) - - # 学習を準備する - if cache_latents: - vae.to(accelerator.device, dtype=weight_dtype) - vae.requires_grad_(False) - vae.eval() - with torch.no_grad(): - train_dataset_group.cache_latents( - vae, - args.vae_batch_size, - args.cache_latents_to_disk, - accelerator.is_main_process, - ) - vae.to("cpu") - clean_memory_on_device(accelerator.device) - - accelerator.wait_for_everyone() - - if args.gradient_checkpointing: - controlnet.enable_gradient_checkpointing() - - # 学習に必要なクラスを準備する - accelerator.print("prepare optimizer, data loader etc.") - - trainable_params = list(controlnet.parameters()) - - _, _, optimizer = train_util.get_optimizer(args, trainable_params) - - # dataloaderを準備する - # DataLoaderのプロセス数:0 は persistent_workers が使えないので注意 - n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers - - train_dataloader = torch.utils.data.DataLoader( - train_dataset_group, - batch_size=1, - shuffle=True, - collate_fn=collator, - num_workers=n_workers, - persistent_workers=args.persistent_data_loader_workers, - ) - - # 学習ステップ数を計算する - if args.max_train_epochs is not None: - args.max_train_steps = args.max_train_epochs * math.ceil( - len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps - ) - accelerator.print( - f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}" - ) - - # データセット側にも学習ステップを送信 - train_dataset_group.set_max_train_steps(args.max_train_steps) - - # lr schedulerを用意する - lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) - - # 実験的機能:勾配も含めたfp16学習を行う モデル全体をfp16にする - if args.full_fp16: - assert ( - args.mixed_precision == "fp16" - ), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。" - accelerator.print("enable full fp16 training.") - controlnet.to(weight_dtype) - - # acceleratorがなんかよろしくやってくれるらしい - controlnet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( - controlnet, optimizer, train_dataloader, lr_scheduler - ) - - unet.requires_grad_(False) - text_encoder.requires_grad_(False) - unet.to(accelerator.device) - text_encoder.to(accelerator.device) - - # transform DDP after prepare - controlnet = controlnet.module if isinstance(controlnet, DDP) else controlnet - - controlnet.train() - - if not cache_latents: - vae.requires_grad_(False) - vae.eval() - vae.to(accelerator.device, dtype=weight_dtype) - - # 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする - if args.full_fp16: - train_util.patch_accelerator_for_fp16_training(accelerator) - - # resumeする - train_util.resume_from_local_or_hf_if_specified(accelerator, args) - - # epoch数を計算する - num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) - num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) - if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0): - args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1 - - # 学習する - # TODO: find a way to handle total batch size when there are multiple datasets - accelerator.print("running training / 学習開始") - accelerator.print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}") - accelerator.print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}") - accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}") - accelerator.print(f" num epochs / epoch数: {num_train_epochs}") - accelerator.print( - f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}" - ) - # logger.info(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}") - accelerator.print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}") - accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}") - - progress_bar = tqdm( - range(args.max_train_steps), - smoothing=0, - disable=not accelerator.is_local_main_process, - desc="steps", - ) - global_step = 0 - - noise_scheduler = DDPMScheduler( - beta_start=0.00085, - beta_end=0.012, - beta_schedule="scaled_linear", - num_train_timesteps=1000, - clip_sample=False, - ) - if accelerator.is_main_process: - init_kwargs = {} - if args.wandb_run_name: - init_kwargs["wandb"] = {"name": args.wandb_run_name} - if args.log_tracker_config is not None: - init_kwargs = toml.load(args.log_tracker_config) - accelerator.init_trackers( - "controlnet_train" if args.log_tracker_name is None else args.log_tracker_name, - config=train_util.get_sanitized_config_or_none(args), - init_kwargs=init_kwargs, - ) - - loss_recorder = train_util.LossRecorder() - del train_dataset_group - - # function for saving/removing - def save_model(ckpt_name, model, force_sync_upload=False): - os.makedirs(args.output_dir, exist_ok=True) - ckpt_file = os.path.join(args.output_dir, ckpt_name) - - accelerator.print(f"\nsaving checkpoint: {ckpt_file}") - - state_dict = model_util.convert_controlnet_state_dict_to_sd(model.state_dict()) - - if save_dtype is not None: - for key in list(state_dict.keys()): - v = state_dict[key] - v = v.detach().clone().to("cpu").to(save_dtype) - state_dict[key] = v - - if os.path.splitext(ckpt_file)[1] == ".safetensors": - from safetensors.torch import save_file - - save_file(state_dict, ckpt_file) - else: - torch.save(state_dict, ckpt_file) - - if args.huggingface_repo_id is not None: - huggingface_util.upload(args, ckpt_file, "/" + ckpt_name, force_sync_upload=force_sync_upload) - - def remove_model(old_ckpt_name): - old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name) - if os.path.exists(old_ckpt_file): - accelerator.print(f"removing old checkpoint: {old_ckpt_file}") - os.remove(old_ckpt_file) - - # For --sample_at_first - train_util.sample_images( - accelerator, args, 0, global_step, accelerator.device, vae, tokenizer, text_encoder, unet, controlnet=controlnet - ) - - # training loop - for epoch in range(num_train_epochs): - if is_main_process: - accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}") - current_epoch.value = epoch + 1 - - for step, batch in enumerate(train_dataloader): - current_step.value = global_step - with accelerator.accumulate(controlnet): - with torch.no_grad(): - if "latents" in batch and batch["latents"] is not None: - latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype) - else: - # latentに変換 - latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample() - latents = latents * 0.18215 - b_size = latents.shape[0] - - input_ids = batch["input_ids"].to(accelerator.device) - encoder_hidden_states = train_util.get_hidden_states(args, input_ids, tokenizer, text_encoder, weight_dtype) - - # Sample noise that we'll add to the latents - noise = torch.randn_like(latents, device=latents.device) - if args.noise_offset: - noise = apply_noise_offset(latents, noise, args.noise_offset, args.adaptive_noise_scale) - elif args.multires_noise_iterations: - noise = pyramid_noise_like( - noise, - latents.device, - args.multires_noise_iterations, - args.multires_noise_discount, - ) - - # Sample a random timestep for each image - timesteps, huber_c = train_util.get_timesteps_and_huber_c( - args, 0, noise_scheduler.config.num_train_timesteps, noise_scheduler, b_size, latents.device - ) - - # Add noise to the latents according to the noise magnitude at each timestep - # (this is the forward diffusion process) - noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) - - controlnet_image = batch["conditioning_images"].to(dtype=weight_dtype) - - with accelerator.autocast(): - down_block_res_samples, mid_block_res_sample = controlnet( - noisy_latents, - timesteps, - encoder_hidden_states=encoder_hidden_states, - controlnet_cond=controlnet_image, - return_dict=False, - ) - - # Predict the noise residual - noise_pred = unet( - noisy_latents, - timesteps, - encoder_hidden_states, - down_block_additional_residuals=[sample.to(dtype=weight_dtype) for sample in down_block_res_samples], - mid_block_additional_residual=mid_block_res_sample.to(dtype=weight_dtype), - ).sample - - if args.v_parameterization: - # v-parameterization training - target = noise_scheduler.get_velocity(latents, noise, timesteps) - else: - target = noise - - loss = train_util.conditional_loss( - noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c - ) - loss = loss.mean([1, 2, 3]) - - loss_weights = batch["loss_weights"] # 各sampleごとのweight - loss = loss * loss_weights - - if args.min_snr_gamma: - loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization) - - loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし - - accelerator.backward(loss) - if accelerator.sync_gradients and args.max_grad_norm != 0.0: - params_to_clip = controlnet.parameters() - accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) - - optimizer.step() - lr_scheduler.step() - optimizer.zero_grad(set_to_none=True) - - # Checks if the accelerator has performed an optimization step behind the scenes - if accelerator.sync_gradients: - progress_bar.update(1) - global_step += 1 - - train_util.sample_images( - accelerator, - args, - None, - global_step, - accelerator.device, - vae, - tokenizer, - text_encoder, - unet, - controlnet=controlnet, - ) - - # 指定ステップごとにモデルを保存 - if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0: - accelerator.wait_for_everyone() - if accelerator.is_main_process: - ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, global_step) - save_model( - ckpt_name, - accelerator.unwrap_model(controlnet), - ) - - if args.save_state: - train_util.save_and_remove_state_stepwise(args, accelerator, global_step) - - remove_step_no = train_util.get_remove_step_no(args, global_step) - if remove_step_no is not None: - remove_ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, remove_step_no) - remove_model(remove_ckpt_name) - - current_loss = loss.detach().item() - loss_recorder.add(epoch=epoch, step=step, loss=current_loss) - avr_loss: float = loss_recorder.moving_average - logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]} - progress_bar.set_postfix(**logs) - - if args.logging_dir is not None: - logs = generate_step_logs(args, current_loss, avr_loss, lr_scheduler) - accelerator.log(logs, step=global_step) - - if global_step >= args.max_train_steps: - break - - if args.logging_dir is not None: - logs = {"loss/epoch": loss_recorder.moving_average} - accelerator.log(logs, step=epoch + 1) - - accelerator.wait_for_everyone() - - # 指定エポックごとにモデルを保存 - if args.save_every_n_epochs is not None: - saving = (epoch + 1) % args.save_every_n_epochs == 0 and (epoch + 1) < num_train_epochs - if is_main_process and saving: - ckpt_name = train_util.get_epoch_ckpt_name(args, "." + args.save_model_as, epoch + 1) - save_model(ckpt_name, accelerator.unwrap_model(controlnet)) - - remove_epoch_no = train_util.get_remove_epoch_no(args, epoch + 1) - if remove_epoch_no is not None: - remove_ckpt_name = train_util.get_epoch_ckpt_name(args, "." + args.save_model_as, remove_epoch_no) - remove_model(remove_ckpt_name) - - if args.save_state: - train_util.save_and_remove_state_on_epoch_end(args, accelerator, epoch + 1) - - train_util.sample_images( - accelerator, - args, - epoch + 1, - global_step, - accelerator.device, - vae, - tokenizer, - text_encoder, - unet, - controlnet=controlnet, - ) - - # end of epoch - if is_main_process: - controlnet = accelerator.unwrap_model(controlnet) - - accelerator.end_training() - - if is_main_process and (args.save_state or args.save_state_on_train_end): - train_util.save_state_on_train_end(args, accelerator) - - # del accelerator # この後メモリを使うのでこれは消す→printで使うので消さずにおく - - if is_main_process: - ckpt_name = train_util.get_last_ckpt_name(args, "." + args.save_model_as) - save_model(ckpt_name, controlnet, force_sync_upload=True) - - logger.info("model saved.") - - -def setup_parser() -> argparse.ArgumentParser: - parser = argparse.ArgumentParser() - - add_logging_arguments(parser) - train_util.add_sd_models_arguments(parser) - train_util.add_dataset_arguments(parser, False, True, True) - train_util.add_training_arguments(parser, False) - deepspeed_utils.add_deepspeed_arguments(parser) - train_util.add_optimizer_arguments(parser) - config_util.add_config_arguments(parser) - custom_train_functions.add_custom_train_arguments(parser) - - parser.add_argument( - "--save_model_as", - type=str, - default="safetensors", - choices=[None, "ckpt", "pt", "safetensors"], - help="format to save the model (default is .safetensors) / モデル保存時の形式(デフォルトはsafetensors)", - ) - parser.add_argument( - "--controlnet_model_name_or_path", - type=str, - default=None, - help="controlnet model name or path / controlnetのモデル名またはパス", - ) - parser.add_argument( - "--conditioning_data_dir", - type=str, - default=None, - help="conditioning data directory / 条件付けデータのディレクトリ", - ) - - return parser - +from library import train_util +from train_control_net import setup_parser, train if __name__ == "__main__": + logger.warning( + "The module 'train_controlnet.py' is deprecated. Please use 'train_control_net.py' instead" + " / 'train_controlnet.py'は非推奨です。代わりに'train_control_net.py'を使用してください。" + ) parser = setup_parser() args = parser.parse_args() diff --git a/train_db.py b/train_db.py index e7cf3cde3..4bf3b31ce 100644 --- a/train_db.py +++ b/train_db.py @@ -11,7 +11,7 @@ from tqdm import tqdm import torch -from library import deepspeed_utils +from library import deepspeed_utils, strategy_base from library.device_utils import init_ipex, clean_memory_on_device @@ -22,6 +22,7 @@ import library.train_util as train_util import library.config_util as config_util +import library.sai_model_spec as sai_model_spec from library.config_util import ( ConfigSanitizer, BlueprintGenerator, @@ -38,6 +39,7 @@ apply_masked_loss, ) from library.utils import setup_logging, add_logging_arguments +import library.strategy_sd as strategy_sd setup_logging() import logging @@ -58,7 +60,14 @@ def train(args): if args.seed is not None: set_seed(args.seed) # 乱数系列を初期化する - tokenizer = train_util.load_tokenizer(args) + tokenize_strategy = strategy_sd.SdTokenizeStrategy(args.v2, args.max_token_length, args.tokenizer_cache_dir) + strategy_base.TokenizeStrategy.set_strategy(tokenize_strategy) + + # prepare caching strategy: this must be set before preparing dataset. because dataset may use this strategy for initialization. + latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy( + False, args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check + ) + strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy) # データセットを準備する if args.dataset_class is None: @@ -80,10 +89,11 @@ def train(args): ] } - blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer) - train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) + blueprint = blueprint_generator.generate(user_config, args) + train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) else: - train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizer) + train_dataset_group = train_util.load_arbitrary_dataset(args) + val_dataset_group = None current_epoch = Value("i", 0) current_step = Value("i", 0) @@ -147,13 +157,17 @@ def train(args): vae.to(accelerator.device, dtype=vae_dtype) vae.requires_grad_(False) vae.eval() - with torch.no_grad(): - train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process) + + train_dataset_group.new_cache_latents(vae, accelerator) + vae.to("cpu") clean_memory_on_device(accelerator.device) accelerator.wait_for_everyone() + text_encoding_strategy = strategy_sd.SdTextEncodingStrategy(args.clip_skip) + strategy_base.TextEncodingStrategy.set_strategy(text_encoding_strategy) + # 学習を準備する:モデルを適切な状態にする train_text_encoder = args.stop_text_encoder_training is None or args.stop_text_encoder_training >= 0 unet.requires_grad_(True) # 念のため追加 @@ -186,8 +200,11 @@ def train(args): _, _, optimizer = train_util.get_optimizer(args, trainable_params) - # dataloaderを準備する - # DataLoaderのプロセス数:0 は persistent_workers が使えないので注意 + # prepare dataloader + # strategies are set here because they cannot be referenced in another process. Copy them with the dataset + # some strategies can be None + train_dataset_group.set_current_strategies() + n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers train_dataloader = torch.utils.data.DataLoader( train_dataset_group, @@ -292,10 +309,19 @@ def train(args): init_kwargs["wandb"] = {"name": args.wandb_run_name} if args.log_tracker_config is not None: init_kwargs = toml.load(args.log_tracker_config) - accelerator.init_trackers("dreambooth" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.get_sanitized_config_or_none(args), init_kwargs=init_kwargs) + accelerator.init_trackers( + "dreambooth" if args.log_tracker_name is None else args.log_tracker_name, + config=train_util.get_sanitized_config_or_none(args), + init_kwargs=init_kwargs, + ) # For --sample_at_first - train_util.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizer, text_encoder, unet) + train_util.sample_images( + accelerator, args, 0, global_step, accelerator.device, vae, tokenize_strategy.tokenizer, text_encoder, unet + ) + if len(accelerator.trackers) > 0: + # log empty object to commit the sample images to wandb + accelerator.log({}, step=0) loss_recorder = train_util.LossRecorder() for epoch in range(num_train_epochs): @@ -332,23 +358,21 @@ def train(args): # Get the text embedding for conditioning with torch.set_grad_enabled(global_step < args.stop_text_encoder_training): if args.weighted_captions: - encoder_hidden_states = get_weighted_text_embeddings( - tokenizer, - text_encoder, - batch["captions"], - accelerator.device, - args.max_token_length // 75 if args.max_token_length else 1, - clip_skip=args.clip_skip, - ) + input_ids_list, weights_list = tokenize_strategy.tokenize_with_weights(batch["captions"]) + encoder_hidden_states = text_encoding_strategy.encode_tokens_with_weights( + tokenize_strategy, [text_encoder], input_ids_list, weights_list + )[0] else: - input_ids = batch["input_ids"].to(accelerator.device) - encoder_hidden_states = train_util.get_hidden_states( - args, input_ids, tokenizer, text_encoder, None if not args.full_fp16 else weight_dtype - ) + input_ids = batch["input_ids_list"][0].to(accelerator.device) + encoder_hidden_states = text_encoding_strategy.encode_tokens( + tokenize_strategy, [text_encoder], [input_ids] + )[0] + if args.full_fp16: + encoder_hidden_states = encoder_hidden_states.to(weight_dtype) # Sample noise, sample a random timestep for each image, and add noise to the latents, # with noise offset and/or multires noise if specified - noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents) + noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents) # Predict the noise residual with accelerator.autocast(): @@ -360,7 +384,8 @@ def train(args): else: target = noise - loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c) + huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler) + loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c) if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None): loss = apply_masked_loss(loss, batch) loss = loss.mean([1, 2, 3]) @@ -395,7 +420,7 @@ def train(args): global_step += 1 train_util.sample_images( - accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet + accelerator, args, None, global_step, accelerator.device, vae, tokenize_strategy.tokenizer, text_encoder, unet ) # 指定ステップごとにモデルを保存 @@ -420,7 +445,7 @@ def train(args): ) current_loss = loss.detach().item() - if args.logging_dir is not None: + if len(accelerator.trackers) > 0: logs = {"loss": current_loss} train_util.append_lr_to_logs(logs, lr_scheduler, args.optimizer_type, including_unet=True) accelerator.log(logs, step=global_step) @@ -433,7 +458,7 @@ def train(args): if global_step >= args.max_train_steps: break - if args.logging_dir is not None: + if len(accelerator.trackers) > 0: logs = {"loss/epoch": loss_recorder.moving_average} accelerator.log(logs, step=epoch + 1) @@ -459,7 +484,9 @@ def train(args): vae, ) - train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet) + train_util.sample_images( + accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenize_strategy.tokenizer, text_encoder, unet + ) is_main_process = accelerator.is_main_process if is_main_process: @@ -486,6 +513,7 @@ def setup_parser() -> argparse.ArgumentParser: add_logging_arguments(parser) train_util.add_sd_models_arguments(parser) + sai_model_spec.add_model_spec_arguments(parser) train_util.add_dataset_arguments(parser, True, False, True) train_util.add_training_arguments(parser, True) train_util.add_masked_loss_arguments(parser) diff --git a/train_network.py b/train_network.py index 6953bb175..6cebf5fc7 100644 --- a/train_network.py +++ b/train_network.py @@ -1,24 +1,31 @@ +import gc import importlib import argparse import math import os +import typing +from typing import Any, List, Union, Optional import sys import random import time import json from multiprocessing import Value -import toml +import numpy as np from tqdm import tqdm import torch +import torch.nn as nn +from torch.types import Number from library.device_utils import init_ipex, clean_memory_on_device init_ipex() from accelerate.utils import set_seed +from accelerate import Accelerator from diffusers import DDPMScheduler -from library import deepspeed_utils, model_util +from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL +from library import deepspeed_utils, model_util, sai_model_spec, strategy_base, strategy_sd, sai_model_spec import library.train_util as train_util from library.train_util import DreamBoothDataset @@ -59,16 +66,24 @@ def generate_step_logs( avr_loss, lr_scheduler, lr_descriptions, + optimizer=None, keys_scaled=None, mean_norm=None, maximum_norm=None, + mean_grad_norm=None, + mean_combined_norm=None, ): logs = {"loss/current": current_loss, "loss/average": avr_loss} if keys_scaled is not None: logs["max_norm/keys_scaled"] = keys_scaled - logs["max_norm/average_key_norm"] = mean_norm logs["max_norm/max_key_norm"] = maximum_norm + if mean_norm is not None: + logs["norm/avg_key_norm"] = mean_norm + if mean_grad_norm is not None: + logs["norm/avg_grad_norm"] = mean_grad_norm + if mean_combined_norm is not None: + logs["norm/avg_combined_norm"] = mean_combined_norm lrs = lr_scheduler.get_last_lr() for i, lr in enumerate(lrs): @@ -91,39 +106,127 @@ def generate_step_logs( logs[f"lr/d*lr/{lr_desc}"] = ( lr_scheduler.optimizers[-1].param_groups[i]["d"] * lr_scheduler.optimizers[-1].param_groups[i]["lr"] ) + if ( + args.optimizer_type.lower().endswith("ProdigyPlusScheduleFree".lower()) and optimizer is not None + ): # tracking d*lr value of unet. + logs["lr/d*lr"] = optimizer.param_groups[0]["d"] * optimizer.param_groups[0]["lr"] + else: + idx = 0 + if not args.network_train_unet_only: + logs["lr/textencoder"] = float(lrs[0]) + idx = 1 + + for i in range(idx, len(lrs)): + logs[f"lr/group{i}"] = float(lrs[i]) + if args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower(): + logs[f"lr/d*lr/group{i}"] = ( + lr_scheduler.optimizers[-1].param_groups[i]["d"] * lr_scheduler.optimizers[-1].param_groups[i]["lr"] + ) + if args.optimizer_type.lower().endswith("ProdigyPlusScheduleFree".lower()) and optimizer is not None: + logs[f"lr/d*lr/group{i}"] = optimizer.param_groups[i]["d"] * optimizer.param_groups[i]["lr"] return logs - def assert_extra_args(self, args, train_dataset_group): + def step_logging(self, accelerator: Accelerator, logs: dict, global_step: int, epoch: int): + self.accelerator_logging(accelerator, logs, global_step, global_step, epoch) + + def epoch_logging(self, accelerator: Accelerator, logs: dict, global_step: int, epoch: int): + self.accelerator_logging(accelerator, logs, epoch, global_step, epoch) + + def val_logging(self, accelerator: Accelerator, logs: dict, global_step: int, epoch: int, val_step: int): + self.accelerator_logging(accelerator, logs, global_step + val_step, global_step, epoch, val_step) + + def accelerator_logging( + self, accelerator: Accelerator, logs: dict, step_value: int, global_step: int, epoch: int, val_step: Optional[int] = None + ): + """ + step_value is for tensorboard, other values are for wandb + """ + tensorboard_tracker = None + wandb_tracker = None + other_trackers = [] + for tracker in accelerator.trackers: + if tracker.name == "tensorboard": + tensorboard_tracker = accelerator.get_tracker("tensorboard") + elif tracker.name == "wandb": + wandb_tracker = accelerator.get_tracker("wandb") + else: + other_trackers.append(accelerator.get_tracker(tracker.name)) + + if tensorboard_tracker is not None: + tensorboard_tracker.log(logs, step=step_value) + + if wandb_tracker is not None: + logs["global_step"] = global_step + logs["epoch"] = epoch + if val_step is not None: + logs["val_step"] = val_step + wandb_tracker.log(logs) + + for tracker in other_trackers: + tracker.log(logs, step=step_value) + + def assert_extra_args( + self, + args, + train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset], + val_dataset_group: Optional[train_util.DatasetGroup], + ): train_dataset_group.verify_bucket_reso_steps(64) + if val_dataset_group is not None: + val_dataset_group.verify_bucket_reso_steps(64) - def load_target_model(self, args, weight_dtype, accelerator): + def load_target_model(self, args, weight_dtype, accelerator) -> tuple[str, nn.Module, nn.Module, Optional[nn.Module]]: text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype, accelerator) + + # モデルに xformers とか memory efficient attention を組み込む + train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa) + if torch.__version__ >= "2.0.0": # PyTorch 2.0.0 以上対応のxformersなら以下が使える + vae.set_use_memory_efficient_attention_xformers(args.xformers) + return model_util.get_model_version_str_for_sd1_sd2(args.v2, args.v_parameterization), text_encoder, vae, unet - def load_tokenizer(self, args): - tokenizer = train_util.load_tokenizer(args) - return tokenizer + def load_unet_lazily(self, args, weight_dtype, accelerator, text_encoders) -> tuple[nn.Module, List[nn.Module]]: + raise NotImplementedError() + + def get_tokenize_strategy(self, args): + return strategy_sd.SdTokenizeStrategy(args.v2, args.max_token_length, args.tokenizer_cache_dir) + + def get_tokenizers(self, tokenize_strategy: strategy_sd.SdTokenizeStrategy) -> List[Any]: + return [tokenize_strategy.tokenizer] + + def get_latents_caching_strategy(self, args): + latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy( + True, args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check + ) + return latents_caching_strategy + + def get_text_encoding_strategy(self, args): + return strategy_sd.SdTextEncodingStrategy(args.clip_skip) - def is_text_encoder_outputs_cached(self, args): - return False + def get_text_encoder_outputs_caching_strategy(self, args): + return None + + def get_models_for_text_encoding(self, args, accelerator, text_encoders): + """ + Returns a list of models that will be used for text encoding. SDXL uses wrapped and unwrapped models. + FLUX.1 and SD3 may cache some outputs of the text encoder, so return the models that will be used for encoding (not cached). + """ + return text_encoders + + # returns a list of bool values indicating whether each text encoder should be trained + def get_text_encoders_train_flags(self, args, text_encoders): + return [True] * len(text_encoders) if self.is_train_text_encoder(args) else [False] * len(text_encoders) def is_train_text_encoder(self, args): - return not args.network_train_unet_only and not self.is_text_encoder_outputs_cached(args) + return not args.network_train_unet_only - def cache_text_encoder_outputs_if_needed( - self, args, accelerator, unet, vae, tokenizers, text_encoders, data_loader, weight_dtype - ): + def cache_text_encoder_outputs_if_needed(self, args, accelerator, unet, vae, text_encoders, dataset, weight_dtype): for t_enc in text_encoders: t_enc.to(accelerator.device, dtype=weight_dtype) - def get_text_cond(self, args, accelerator, batch, tokenizers, text_encoders, weight_dtype): - input_ids = batch["input_ids"].to(accelerator.device) - encoder_hidden_states = train_util.get_hidden_states(args, input_ids, tokenizers[0], text_encoders[0], weight_dtype) - return encoder_hidden_states - - def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_conds, batch, weight_dtype): - noise_pred = unet(noisy_latents, timesteps, text_conds).sample + def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_conds, batch, weight_dtype, **kwargs): + noise_pred = unet(noisy_latents, timesteps, text_conds[0]).sample return noise_pred def all_reduce_network(self, accelerator, network): @@ -131,8 +234,259 @@ def all_reduce_network(self, accelerator, network): if param.grad is not None: param.grad = accelerator.reduce(param.grad, reduction="mean") - def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet): - train_util.sample_images(accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet) + def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizers, text_encoder, unet): + train_util.sample_images(accelerator, args, epoch, global_step, device, vae, tokenizers[0], text_encoder, unet) + + # region SD/SDXL + + def post_process_network(self, args, accelerator, network, text_encoders, unet): + pass + + def get_noise_scheduler(self, args: argparse.Namespace, device: torch.device) -> Any: + noise_scheduler = DDPMScheduler( + beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False + ) + prepare_scheduler_for_custom_training(noise_scheduler, device) + if args.zero_terminal_snr: + custom_train_functions.fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler) + return noise_scheduler + + def encode_images_to_latents(self, args, vae: AutoencoderKL, images: torch.FloatTensor) -> torch.FloatTensor: + return vae.encode(images).latent_dist.sample() + + def shift_scale_latents(self, args, latents: torch.FloatTensor) -> torch.FloatTensor: + return latents * self.vae_scale_factor + + def get_noise_pred_and_target( + self, + args, + accelerator, + noise_scheduler, + latents, + batch, + text_encoder_conds, + unet, + network, + weight_dtype, + train_unet, + is_train=True, + ): + # Sample noise, sample a random timestep for each image, and add noise to the latents, + # with noise offset and/or multires noise if specified + noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents) + + # ensure the hidden state will require grad + if args.gradient_checkpointing: + for x in noisy_latents: + x.requires_grad_(True) + for t in text_encoder_conds: + t.requires_grad_(True) + + # Predict the noise residual + with torch.set_grad_enabled(is_train), accelerator.autocast(): + noise_pred = self.call_unet( + args, + accelerator, + unet, + noisy_latents.requires_grad_(train_unet), + timesteps, + text_encoder_conds, + batch, + weight_dtype, + ) + + if args.v_parameterization: + # v-parameterization training + target = noise_scheduler.get_velocity(latents, noise, timesteps) + else: + target = noise + + # differential output preservation + if "custom_attributes" in batch: + diff_output_pr_indices = [] + for i, custom_attributes in enumerate(batch["custom_attributes"]): + if "diff_output_preservation" in custom_attributes and custom_attributes["diff_output_preservation"]: + diff_output_pr_indices.append(i) + + if len(diff_output_pr_indices) > 0: + network.set_multiplier(0.0) + with torch.no_grad(), accelerator.autocast(): + noise_pred_prior = self.call_unet( + args, + accelerator, + unet, + noisy_latents, + timesteps, + text_encoder_conds, + batch, + weight_dtype, + indices=diff_output_pr_indices, + ) + network.set_multiplier(1.0) # may be overwritten by "network_multipliers" in the next step + target[diff_output_pr_indices] = noise_pred_prior.to(target.dtype) + + return noise_pred, target, timesteps, None + + def post_process_loss(self, loss, args, timesteps: torch.IntTensor, noise_scheduler) -> torch.FloatTensor: + if args.min_snr_gamma: + loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization) + if args.scale_v_pred_loss_like_noise_pred: + loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler) + if args.v_pred_like_loss: + loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss) + if args.debiased_estimation_loss: + loss = apply_debiased_estimation(loss, timesteps, noise_scheduler, args.v_parameterization) + return loss + + def get_sai_model_spec(self, args): + return train_util.get_sai_model_spec(None, args, self.is_sdxl, True, False) + + def update_metadata(self, metadata, args): + pass + + def is_text_encoder_not_needed_for_training(self, args): + return False # use for sample images + + def prepare_text_encoder_grad_ckpt_workaround(self, index, text_encoder): + # set top parameter requires_grad = True for gradient checkpointing works + text_encoder.text_model.embeddings.requires_grad_(True) + + def prepare_text_encoder_fp8(self, index, text_encoder, te_weight_dtype, weight_dtype): + text_encoder.text_model.embeddings.to(dtype=weight_dtype) + + def prepare_unet_with_accelerator( + self, args: argparse.Namespace, accelerator: Accelerator, unet: torch.nn.Module + ) -> torch.nn.Module: + return accelerator.prepare(unet) + + def on_step_start(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype, is_train: bool = True): + pass + + def on_validation_step_end(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype): + pass + + # endregion + + def process_batch( + self, + batch, + text_encoders, + unet, + network, + vae, + noise_scheduler, + vae_dtype, + weight_dtype, + accelerator, + args, + text_encoding_strategy: strategy_base.TextEncodingStrategy, + tokenize_strategy: strategy_base.TokenizeStrategy, + is_train=True, + train_text_encoder=True, + train_unet=True, + ) -> torch.Tensor: + """ + Process a batch for the network + """ + with torch.no_grad(): + if "latents" in batch and batch["latents"] is not None: + latents = typing.cast(torch.FloatTensor, batch["latents"].to(accelerator.device)) + else: + # latentに変換 + if args.vae_batch_size is None or len(batch["images"]) <= args.vae_batch_size: + latents = self.encode_images_to_latents(args, vae, batch["images"].to(accelerator.device, dtype=vae_dtype)) + else: + chunks = [ + batch["images"][i : i + args.vae_batch_size] for i in range(0, len(batch["images"]), args.vae_batch_size) + ] + list_latents = [] + for chunk in chunks: + with torch.no_grad(): + chunk = self.encode_images_to_latents(args, vae, chunk.to(accelerator.device, dtype=vae_dtype)) + list_latents.append(chunk) + latents = torch.cat(list_latents, dim=0) + + # NaNが含まれていれば警告を表示し0に置き換える + if torch.any(torch.isnan(latents)): + accelerator.print("NaN found in latents, replacing with zeros") + latents = typing.cast(torch.FloatTensor, torch.nan_to_num(latents, 0, out=latents)) + + latents = self.shift_scale_latents(args, latents) + + text_encoder_conds = [] + text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None) + if text_encoder_outputs_list is not None: + text_encoder_conds = text_encoder_outputs_list # List of text encoder outputs + + if len(text_encoder_conds) == 0 or text_encoder_conds[0] is None or train_text_encoder: + # TODO this does not work if 'some text_encoders are trained' and 'some are not and not cached' + with torch.set_grad_enabled(is_train and train_text_encoder), accelerator.autocast(): + # Get the text embedding for conditioning + if args.weighted_captions: + input_ids_list, weights_list = tokenize_strategy.tokenize_with_weights(batch["captions"]) + encoded_text_encoder_conds = text_encoding_strategy.encode_tokens_with_weights( + tokenize_strategy, + self.get_models_for_text_encoding(args, accelerator, text_encoders), + input_ids_list, + weights_list, + ) + else: + input_ids = [ids.to(accelerator.device) for ids in batch["input_ids_list"]] + encoded_text_encoder_conds = text_encoding_strategy.encode_tokens( + tokenize_strategy, + self.get_models_for_text_encoding(args, accelerator, text_encoders), + input_ids, + ) + if args.full_fp16: + encoded_text_encoder_conds = [c.to(weight_dtype) for c in encoded_text_encoder_conds] + + # if text_encoder_conds is not cached, use encoded_text_encoder_conds + if len(text_encoder_conds) == 0: + text_encoder_conds = encoded_text_encoder_conds + else: + # if encoded_text_encoder_conds is not None, update cached text_encoder_conds + for i in range(len(encoded_text_encoder_conds)): + if encoded_text_encoder_conds[i] is not None: + text_encoder_conds[i] = encoded_text_encoder_conds[i] + + # sample noise, call unet, get target + noise_pred, target, timesteps, weighting = self.get_noise_pred_and_target( + args, + accelerator, + noise_scheduler, + latents, + batch, + text_encoder_conds, + unet, + network, + weight_dtype, + train_unet, + is_train=is_train, + ) + + huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler) + loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c) + if weighting is not None: + loss = loss * weighting + if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None): + loss = apply_masked_loss(loss, batch) + loss = loss.mean([1, 2, 3]) + + loss_weights = batch["loss_weights"] # 各sampleごとのweight + loss = loss * loss_weights + + loss = self.post_process_loss(loss, args, timesteps, noise_scheduler) + + return loss.mean() + + def cast_text_encoder(self, args): + return True # default for other than HunyuanImage + + def cast_vae(self, args): + return True # default for other than HunyuanImage + + def cast_unet(self, args): + return True # default for other than HunyuanImage def train(self, args): session_id = random.randint(0, 2**32) @@ -150,9 +504,13 @@ def train(self, args): args.seed = random.randint(0, 2**32) set_seed(args.seed) - # tokenizerは単体またはリスト、tokenizersは必ずリスト:既存のコードとの互換性のため - tokenizer = self.load_tokenizer(args) - tokenizers = tokenizer if isinstance(tokenizer, list) else [tokenizer] + tokenize_strategy = self.get_tokenize_strategy(args) + strategy_base.TokenizeStrategy.set_strategy(tokenize_strategy) + tokenizers = self.get_tokenizers(tokenize_strategy) # will be removed after sample_image is refactored + + # prepare caching strategy: this must be set before preparing dataset. because dataset may use this strategy for initialization. + latents_caching_strategy = self.get_latents_caching_strategy(args) + strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy) # データセットを準備する if args.dataset_class is None: @@ -194,11 +552,12 @@ def train(self, args): ] } - blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer) - train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) + blueprint = blueprint_generator.generate(user_config, args) + train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) else: # use arbitrary dataset class - train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizer) + train_dataset_group = train_util.load_arbitrary_dataset(args) + val_dataset_group = None # placeholder until validation dataset supported for arbitrary current_epoch = Value("i", 0) current_step = Value("i", 0) @@ -206,7 +565,12 @@ def train(self, args): collator = train_util.collator_class(current_epoch, current_step, ds_for_collator) if args.debug_dataset: + train_dataset_group.set_current_strategies() # dataset needs to know the strategies explicitly train_util.debug_dataset(train_dataset_group) + + if val_dataset_group is not None: + val_dataset_group.set_current_strategies() # dataset needs to know the strategies explicitly + train_util.debug_dataset(val_dataset_group) return if len(train_dataset_group) == 0: logger.error( @@ -218,8 +582,12 @@ def train(self, args): assert ( train_dataset_group.is_latent_cacheable() ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません" + if val_dataset_group is not None: + assert ( + val_dataset_group.is_latent_cacheable() + ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません" - self.assert_extra_args(args, train_dataset_group) + self.assert_extra_args(args, train_dataset_group, val_dataset_group) # may change some args # acceleratorを準備する logger.info("preparing accelerator") @@ -228,18 +596,47 @@ def train(self, args): # mixed precisionに対応した型を用意しておき適宜castする weight_dtype, save_dtype = train_util.prepare_dtype(args) - vae_dtype = torch.float32 if args.no_half_vae else weight_dtype + vae_dtype = (torch.float32 if args.no_half_vae else weight_dtype) if self.cast_vae(args) else None - # モデルを読み込む + # load target models: unet may be None for lazy loading model_version, text_encoder, vae, unet = self.load_target_model(args, weight_dtype, accelerator) + if vae_dtype is None: + vae_dtype = vae.dtype + logger.info(f"vae_dtype is set to {vae_dtype} by the model since cast_vae() is false") # text_encoder is List[CLIPTextModel] or CLIPTextModel text_encoders = text_encoder if isinstance(text_encoder, list) else [text_encoder] - # モデルに xformers とか memory efficient attention を組み込む - train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa) - if torch.__version__ >= "2.0.0": # PyTorch 2.0.0 以上対応のxformersなら以下が使える - vae.set_use_memory_efficient_attention_xformers(args.xformers) + # prepare dataset for latents caching if needed + if cache_latents: + vae.to(accelerator.device, dtype=vae_dtype) + vae.requires_grad_(False) + vae.eval() + + train_dataset_group.new_cache_latents(vae, accelerator) + if val_dataset_group is not None: + val_dataset_group.new_cache_latents(vae, accelerator) + + vae.to("cpu") + clean_memory_on_device(accelerator.device) + + accelerator.wait_for_everyone() + + # 必要ならテキストエンコーダーの出力をキャッシュする: Text Encoderはcpuまたはgpuへ移される + # cache text encoder outputs if needed: Text Encoder is moved to cpu or gpu + text_encoding_strategy = self.get_text_encoding_strategy(args) + strategy_base.TextEncodingStrategy.set_strategy(text_encoding_strategy) + + text_encoder_outputs_caching_strategy = self.get_text_encoder_outputs_caching_strategy(args) + if text_encoder_outputs_caching_strategy is not None: + strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(text_encoder_outputs_caching_strategy) + self.cache_text_encoder_outputs_if_needed(args, accelerator, unet, vae, text_encoders, train_dataset_group, weight_dtype) + if val_dataset_group is not None: + self.cache_text_encoder_outputs_if_needed(args, accelerator, unet, vae, text_encoders, val_dataset_group, weight_dtype) + + if unet is None: + # lazy load unet if needed. text encoders may be freed or replaced with dummy models for saving memory + unet, text_encoders = self.load_unet_lazily(args, weight_dtype, accelerator, text_encoders) # 差分追加学習のためにモデルを読み込む sys.path.append(os.path.dirname(__file__)) @@ -263,29 +660,11 @@ def train(self, args): accelerator.print(f"all weights merged: {', '.join(args.base_weights)}") - # 学習を準備する - if cache_latents: - vae.to(accelerator.device, dtype=vae_dtype) - vae.requires_grad_(False) - vae.eval() - with torch.no_grad(): - train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process) - vae.to("cpu") - clean_memory_on_device(accelerator.device) - - accelerator.wait_for_everyone() - - # 必要ならテキストエンコーダーの出力をキャッシュする: Text Encoderはcpuまたはgpuへ移される - # cache text encoder outputs if needed: Text Encoder is moved to cpu or gpu - self.cache_text_encoder_outputs_if_needed( - args, accelerator, unet, vae, tokenizers, text_encoders, train_dataset_group, weight_dtype - ) - # prepare network net_kwargs = {} if args.network_args is not None: for net_arg in args.network_args: - key, value = net_arg.split("=") + key, value = net_arg.split("=", 1) net_kwargs[key] = value # if a new network is added in future, add if ~ then blocks for each network (;'∀') @@ -310,6 +689,10 @@ def train(self, args): return network_has_multiplier = hasattr(network, "set_multiplier") + # TODO remove `hasattr` by setting up methods if not defined in the network like below (hacky but will work): + # if not hasattr(network, "prepare_network"): + # network.prepare_network = lambda args: None + if hasattr(network, "prepare_network"): network.prepare_network(args) if args.scale_weight_norms and not hasattr(network, "apply_max_norm_regularization"): @@ -318,28 +701,49 @@ def train(self, args): ) args.scale_weight_norms = False + self.post_process_network(args, accelerator, network, text_encoders, unet) + + # apply network to unet and text_encoder train_unet = not args.network_train_text_encoder_only train_text_encoder = self.is_train_text_encoder(args) network.apply_to(text_encoder, unet, train_text_encoder, train_unet) if args.network_weights is not None: - # FIXME consider alpha of weights + # FIXME consider alpha of weights: this assumes that the alpha is not changed info = network.load_weights(args.network_weights) accelerator.print(f"load network weights from {args.network_weights}: {info}") if args.gradient_checkpointing: - unet.enable_gradient_checkpointing() - for t_enc in text_encoders: - t_enc.gradient_checkpointing_enable() + if args.cpu_offload_checkpointing: + unet.enable_gradient_checkpointing(cpu_offload=True) + else: + unet.enable_gradient_checkpointing() + + for t_enc, flag in zip(text_encoders, self.get_text_encoders_train_flags(args, text_encoders)): + if flag: + if t_enc.supports_gradient_checkpointing: + t_enc.gradient_checkpointing_enable() del t_enc network.enable_gradient_checkpointing() # may have no effect # 学習に必要なクラスを準備する accelerator.print("prepare optimizer, data loader etc.") - # 後方互換性を確保するよ + # make backward compatibility for text_encoder_lr + support_multiple_lrs = hasattr(network, "prepare_optimizer_params_with_multiple_te_lrs") + if support_multiple_lrs: + text_encoder_lr = args.text_encoder_lr + else: + # toml backward compatibility + if args.text_encoder_lr is None or isinstance(args.text_encoder_lr, float) or isinstance(args.text_encoder_lr, int): + text_encoder_lr = args.text_encoder_lr + else: + text_encoder_lr = None if len(args.text_encoder_lr) == 0 else args.text_encoder_lr[0] try: - results = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr, args.learning_rate) + if support_multiple_lrs: + results = network.prepare_optimizer_params_with_multiple_te_lrs(text_encoder_lr, args.unet_lr, args.learning_rate) + else: + results = network.prepare_optimizer_params(text_encoder_lr, args.unet_lr, args.learning_rate) if type(results) is tuple: trainable_params = results[0] lr_descriptions = results[1] @@ -347,11 +751,7 @@ def train(self, args): trainable_params = results lr_descriptions = None except TypeError as e: - # logger.warning(f"{e}") - # accelerator.print( - # "Deprecated: use prepare_optimizer_params(text_encoder_lr, unet_lr, learning_rate) instead of prepare_optimizer_params(text_encoder_lr, unet_lr)" - # ) - trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr) + trainable_params = network.prepare_optimizer_params(text_encoder_lr, args.unet_lr) lr_descriptions = None # if len(trainable_params) == 0: @@ -365,8 +765,15 @@ def train(self, args): # accelerator.print(f"trainable_params: {k} = {v}") optimizer_name, optimizer_args, optimizer = train_util.get_optimizer(args, trainable_params) + optimizer_train_fn, optimizer_eval_fn = train_util.get_optimizer_train_eval_fn(optimizer, args) + + # prepare dataloader + # strategies are set here because they cannot be referenced in another process. Copy them with the dataset + # some strategies can be None + train_dataset_group.set_current_strategies() + if val_dataset_group is not None: + val_dataset_group.set_current_strategies() - # dataloaderを準備する # DataLoaderのプロセス数:0 は persistent_workers が使えないので注意 n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers @@ -379,6 +786,15 @@ def train(self, args): persistent_workers=args.persistent_data_loader_workers, ) + val_dataloader = torch.utils.data.DataLoader( + val_dataset_group if val_dataset_group is not None else [], + shuffle=False, + batch_size=1, + collate_fn=collator, + num_workers=n_workers, + persistent_workers=args.persistent_data_loader_workers, + ) + # 学習ステップ数を計算する if args.max_train_epochs is not None: args.max_train_steps = args.max_train_epochs * math.ceil( @@ -410,67 +826,87 @@ def train(self, args): unet_weight_dtype = te_weight_dtype = weight_dtype # Experimental Feature: Put base model into fp8 to save vram - if args.fp8_base: + if args.fp8_base or args.fp8_base_unet: assert torch.__version__ >= "2.1.0", "fp8_base requires torch>=2.1.0 / fp8を使う場合はtorch>=2.1.0が必要です。" assert ( args.mixed_precision != "no" ), "fp8_base requires mixed precision='fp16' or 'bf16' / fp8を使う場合はmixed_precision='fp16'または'bf16'が必要です。" - accelerator.print("enable fp8 training.") + accelerator.print("enable fp8 training for U-Net.") unet_weight_dtype = torch.float8_e4m3fn - te_weight_dtype = torch.float8_e4m3fn + + if not args.fp8_base_unet: + accelerator.print("enable fp8 training for Text Encoder.") + te_weight_dtype = weight_dtype if args.fp8_base_unet else torch.float8_e4m3fn + + # unet.to(accelerator.device) # this makes faster `to(dtype)` below, but consumes 23 GB VRAM + # unet.to(dtype=unet_weight_dtype) # without moving to gpu, this takes a lot of time and main memory + + # logger.info(f"set U-Net weight dtype to {unet_weight_dtype}, device to {accelerator.device}") + # unet.to(accelerator.device, dtype=unet_weight_dtype) # this seems to be safer than above + logger.info(f"set U-Net weight dtype to {unet_weight_dtype}") + unet.to(dtype=unet_weight_dtype) # do not move to device because unet is not prepared by accelerator unet.requires_grad_(False) - unet.to(dtype=unet_weight_dtype) - for t_enc in text_encoders: + if self.cast_unet(args): + unet.to(dtype=unet_weight_dtype) + for i, t_enc in enumerate(text_encoders): t_enc.requires_grad_(False) # in case of cpu, dtype is already set to fp32 because cpu does not support fp8/fp16/bf16 - if t_enc.device.type != "cpu": + if t_enc.device.type != "cpu" and self.cast_text_encoder(args): t_enc.to(dtype=te_weight_dtype) + # nn.Embedding not support FP8 - t_enc.text_model.embeddings.to(dtype=(weight_dtype if te_weight_dtype != weight_dtype else te_weight_dtype)) + if te_weight_dtype != weight_dtype: + self.prepare_text_encoder_fp8(i, t_enc, te_weight_dtype, weight_dtype) # acceleratorがなんかよろしくやってくれるらしい / accelerator will do something good if args.deepspeed: + flags = self.get_text_encoders_train_flags(args, text_encoders) ds_model = deepspeed_utils.prepare_deepspeed_model( args, unet=unet if train_unet else None, - text_encoder1=text_encoders[0] if train_text_encoder else None, - text_encoder2=text_encoders[1] if train_text_encoder and len(text_encoders) > 1 else None, + text_encoder1=text_encoders[0] if flags[0] else None, + text_encoder2=(text_encoders[1] if flags[1] else None) if len(text_encoders) > 1 else None, network=network, ) - ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( - ds_model, optimizer, train_dataloader, lr_scheduler + ds_model, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare( + ds_model, optimizer, train_dataloader, val_dataloader, lr_scheduler ) training_model = ds_model else: if train_unet: - unet = accelerator.prepare(unet) + # default implementation is: unet = accelerator.prepare(unet) + unet = self.prepare_unet_with_accelerator(args, accelerator, unet) # accelerator does some magic here else: - unet.to(accelerator.device, dtype=unet_weight_dtype) # move to device because unet is not prepared by accelerator + # move to device because unet is not prepared by accelerator + unet.to(accelerator.device, dtype=unet_weight_dtype if self.cast_unet(args) else None) if train_text_encoder: + text_encoders = [ + (accelerator.prepare(t_enc) if flag else t_enc) + for t_enc, flag in zip(text_encoders, self.get_text_encoders_train_flags(args, text_encoders)) + ] if len(text_encoders) > 1: - text_encoder = text_encoders = [accelerator.prepare(t_enc) for t_enc in text_encoders] + text_encoder = text_encoders else: - text_encoder = accelerator.prepare(text_encoder) - text_encoders = [text_encoder] + text_encoder = text_encoders[0] else: pass # if text_encoder is not trained, no need to prepare. and device and dtype are already set - network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( - network, optimizer, train_dataloader, lr_scheduler + network, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare( + network, optimizer, train_dataloader, val_dataloader, lr_scheduler ) training_model = network if args.gradient_checkpointing: # according to TI example in Diffusers, train is required unet.train() - for t_enc in text_encoders: + for i, (t_enc, frag) in enumerate(zip(text_encoders, self.get_text_encoders_train_flags(args, text_encoders))): t_enc.train() # set top parameter requires_grad = True for gradient checkpointing works - if train_text_encoder: - t_enc.text_model.embeddings.requires_grad_(True) + if frag: + self.prepare_text_encoder_grad_ckpt_workaround(i, t_enc) else: unet.eval() @@ -550,6 +986,9 @@ def load_model_hook(models, input_dir): accelerator.print("running training / 学習開始") accelerator.print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}") + accelerator.print( + f" num validation images * repeats / 学習画像の数×繰り返し回数: {val_dataset_group.num_train_images if val_dataset_group is not None else 0}" + ) accelerator.print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}") accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}") accelerator.print(f" num epochs / epoch数: {num_train_epochs}") @@ -566,9 +1005,10 @@ def load_model_hook(models, input_dir): "ss_training_started_at": training_started_at, # unix timestamp "ss_output_name": args.output_name, "ss_learning_rate": args.learning_rate, - "ss_text_encoder_lr": args.text_encoder_lr, + "ss_text_encoder_lr": text_encoder_lr, "ss_unet_lr": args.unet_lr, "ss_num_train_images": train_dataset_group.num_train_images, + "ss_num_validation_images": val_dataset_group.num_train_images if val_dataset_group is not None else 0, "ss_num_reg_images": train_dataset_group.num_reg_images, "ss_num_batches_per_epoch": len(train_dataloader), "ss_num_epochs": num_train_epochs, @@ -612,9 +1052,20 @@ def load_model_hook(models, input_dir): "ss_ip_noise_gamma_random_strength": args.ip_noise_gamma_random_strength, "ss_loss_type": args.loss_type, "ss_huber_schedule": args.huber_schedule, + "ss_huber_scale": args.huber_scale, "ss_huber_c": args.huber_c, + "ss_fp8_base": bool(args.fp8_base), + "ss_fp8_base_unet": bool(args.fp8_base_unet), + "ss_validation_seed": args.validation_seed, + "ss_validation_split": args.validation_split, + "ss_max_validation_steps": args.max_validation_steps, + "ss_validate_every_n_epochs": args.validate_every_n_epochs, + "ss_validate_every_n_steps": args.validate_every_n_steps, + "ss_resize_interpolation": args.resize_interpolation, } + self.update_metadata(metadata, args) # architecture specific metadata + if use_user_config: # save metadata of multiple datasets # NOTE: pack "ss_datasets" value as json one time @@ -636,6 +1087,7 @@ def load_model_hook(models, input_dir): "max_bucket_reso": dataset.max_bucket_reso, "tag_frequency": dataset.tag_frequency, "bucket_info": dataset.bucket_info, + "resize_interpolation": dataset.resize_interpolation, } subsets_metadata = [] @@ -653,6 +1105,7 @@ def load_model_hook(models, input_dir): "enable_wildcard": bool(subset.enable_wildcard), "caption_prefix": subset.caption_prefix, "caption_suffix": subset.caption_suffix, + "resize_interpolation": subset.resize_interpolation, } image_dir_or_metadata_file = None @@ -801,10 +1254,6 @@ def load_model_hook(models, input_dir): args.max_train_steps > initial_step ), f"max_train_steps should be greater than initial step / max_train_stepsは初期ステップより大きい必要があります: {args.max_train_steps} vs {initial_step}" - progress_bar = tqdm( - range(args.max_train_steps - initial_step), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps" - ) - epoch_to_start = 0 if initial_step > 0: if args.skip_until_initial_step: @@ -825,33 +1274,23 @@ def load_model_hook(models, input_dir): global_step = 0 - noise_scheduler = DDPMScheduler( - beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False - ) - prepare_scheduler_for_custom_training(noise_scheduler, accelerator.device) - if args.zero_terminal_snr: - custom_train_functions.fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler) + noise_scheduler = self.get_noise_scheduler(args, accelerator.device) - if accelerator.is_main_process: - init_kwargs = {} - if args.wandb_run_name: - init_kwargs["wandb"] = {"name": args.wandb_run_name} - if args.log_tracker_config is not None: - init_kwargs = toml.load(args.log_tracker_config) - accelerator.init_trackers( - "network_train" if args.log_tracker_name is None else args.log_tracker_name, - config=train_util.get_sanitized_config_or_none(args), - init_kwargs=init_kwargs, - ) + train_util.init_trackers(accelerator, args, "network_train") loss_recorder = train_util.LossRecorder() + val_step_loss_recorder = train_util.LossRecorder() + val_epoch_loss_recorder = train_util.LossRecorder() + del train_dataset_group + if val_dataset_group is not None: + del val_dataset_group # callback for step start if hasattr(accelerator.unwrap_model(network), "on_step_start"): - on_step_start = accelerator.unwrap_model(network).on_step_start + on_step_start_for_network = accelerator.unwrap_model(network).on_step_start else: - on_step_start = lambda *args, **kwargs: None + on_step_start_for_network = lambda *args, **kwargs: None # function for saving/removing def save_model(ckpt_name, unwrapped_nw, steps, epoch_no, force_sync_upload=False): @@ -864,7 +1303,7 @@ def save_model(ckpt_name, unwrapped_nw, steps, epoch_no, force_sync_upload=False metadata["ss_epoch"] = str(epoch_no) metadata_to_save = minimum_metadata if args.no_metadata else metadata - sai_metadata = train_util.get_sai_model_spec(None, args, self.is_sdxl, True, False) + sai_metadata = self.get_sai_model_spec(args) metadata_to_save.update(sai_metadata) unwrapped_nw.save_weights(ckpt_file, save_dtype, metadata_to_save) @@ -877,8 +1316,25 @@ def remove_model(old_ckpt_name): accelerator.print(f"removing old checkpoint: {old_ckpt_file}") os.remove(old_ckpt_file) + # if text_encoder is not needed for training, delete it to save memory. + # TODO this can be automated after SDXL sample prompt cache is implemented + if self.is_text_encoder_not_needed_for_training(args): + logger.info("text_encoder is not needed for training. deleting to save memory.") + for t_enc in text_encoders: + del t_enc + text_encoders = [] + text_encoder = None + gc.collect() + clean_memory_on_device(accelerator.device) + # For --sample_at_first - self.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizer, text_encoder, unet) + optimizer_eval_fn() + self.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizers, text_encoder, unet) + optimizer_train_fn() + is_tracking = len(accelerator.trackers) > 0 + if is_tracking: + # log empty object to commit the sample images to wandb + accelerator.log({}, step=0) # training loop if initial_step > 0: # only if skip_until_initial_step is specified @@ -887,14 +1343,70 @@ def remove_model(old_ckpt_name): initial_step -= len(train_dataloader) global_step = initial_step + # log device and dtype for each model + logger.info(f"unet dtype: {unet_weight_dtype}, device: {unet.device}") + for i, t_enc in enumerate(text_encoders): + params_itr = t_enc.parameters() + params_itr.__next__() # skip the first parameter + params_itr.__next__() # skip the second parameter. because CLIP first two parameters are embeddings + param_3rd = params_itr.__next__() + logger.info(f"text_encoder [{i}] dtype: {param_3rd.dtype}, device: {t_enc.device}") + + clean_memory_on_device(accelerator.device) + + progress_bar = tqdm( + range(args.max_train_steps - initial_step), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps" + ) + + validation_steps = ( + min(args.max_validation_steps, len(val_dataloader)) if args.max_validation_steps is not None else len(val_dataloader) + ) + NUM_VALIDATION_TIMESTEPS = 4 # 200, 400, 600, 800 TODO make this configurable + min_timestep = 0 if args.min_timestep is None else args.min_timestep + max_timestep = noise_scheduler.config.num_train_timesteps if args.max_timestep is None else args.max_timestep + validation_timesteps = np.linspace(min_timestep, max_timestep, (NUM_VALIDATION_TIMESTEPS + 2), dtype=int)[1:-1] + validation_total_steps = validation_steps * len(validation_timesteps) + original_args_min_timestep = args.min_timestep + original_args_max_timestep = args.max_timestep + + def switch_rng_state(seed: int) -> tuple[torch.ByteTensor, Optional[torch.ByteTensor], tuple]: + cpu_rng_state = torch.get_rng_state() + if accelerator.device.type == "cuda": + gpu_rng_state = torch.cuda.get_rng_state() + elif accelerator.device.type == "xpu": + gpu_rng_state = torch.xpu.get_rng_state() + elif accelerator.device.type == "mps": + gpu_rng_state = torch.cuda.get_rng_state() + else: + gpu_rng_state = None + python_rng_state = random.getstate() + + torch.manual_seed(seed) + random.seed(seed) + + return (cpu_rng_state, gpu_rng_state, python_rng_state) + + def restore_rng_state(rng_states: tuple[torch.ByteTensor, Optional[torch.ByteTensor], tuple]): + cpu_rng_state, gpu_rng_state, python_rng_state = rng_states + torch.set_rng_state(cpu_rng_state) + if gpu_rng_state is not None: + if accelerator.device.type == "cuda": + torch.cuda.set_rng_state(gpu_rng_state) + elif accelerator.device.type == "xpu": + torch.xpu.set_rng_state(gpu_rng_state) + elif accelerator.device.type == "mps": + torch.cuda.set_rng_state(gpu_rng_state) + random.setstate(python_rng_state) + for epoch in range(epoch_to_start, num_train_epochs): - accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}") + accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}\n") current_epoch.value = epoch + 1 metadata["ss_epoch"] = str(epoch + 1) - accelerator.unwrap_model(network).on_epoch_start(text_encoder, unet) + accelerator.unwrap_model(network).on_epoch_start(text_encoder, unet) # network.train() is called here + # TRAINING skipped_dataloader = None if initial_step > 0: skipped_dataloader = accelerator.skip_first_batches(train_dataloader, initial_step - 1) @@ -907,108 +1419,28 @@ def remove_model(old_ckpt_name): continue with accelerator.accumulate(training_model): - on_step_start(text_encoder, unet) - - if "latents" in batch and batch["latents"] is not None: - latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype) - else: - if args.vae_batch_size is None or len(batch["images"]) <= args.vae_batch_size: - with torch.no_grad(): - # latentに変換 - latents = vae.encode(batch["images"].to(dtype=vae_dtype)).latent_dist.sample().to(dtype=weight_dtype) - else: - chunks = [batch["images"][i:i + args.vae_batch_size] for i in range(0, len(batch["images"]), args.vae_batch_size)] - list_latents = [] - for chunk in chunks: - with torch.no_grad(): - # latentに変換 - list_latents.append(vae.encode(chunk.to(dtype=vae_dtype)).latent_dist.sample().to(dtype=weight_dtype)) - latents = torch.cat(list_latents, dim=0) - # NaNが含まれていれば警告を表示し0に置き換える - if torch.any(torch.isnan(latents)): - accelerator.print("NaN found in latents, replacing with zeros") - latents = torch.nan_to_num(latents, 0, out=latents) - latents = latents * self.vae_scale_factor - - # get multiplier for each sample - if network_has_multiplier: - multipliers = batch["network_multipliers"] - # if all multipliers are same, use single multiplier - if torch.all(multipliers == multipliers[0]): - multipliers = multipliers[0].item() - else: - raise NotImplementedError("multipliers for each sample is not supported yet") - # print(f"set multiplier: {multipliers}") - accelerator.unwrap_model(network).set_multiplier(multipliers) - - with torch.set_grad_enabled(train_text_encoder), accelerator.autocast(): - # Get the text embedding for conditioning - if args.weighted_captions: - text_encoder_conds = get_weighted_text_embeddings( - tokenizer, - text_encoder, - batch["captions"], - accelerator.device, - args.max_token_length // 75 if args.max_token_length else 1, - clip_skip=args.clip_skip, - ) - else: - text_encoder_conds = self.get_text_cond( - args, accelerator, batch, tokenizers, text_encoders, weight_dtype - ) - - # Sample noise, sample a random timestep for each image, and add noise to the latents, - # with noise offset and/or multires noise if specified - noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps( - args, noise_scheduler, latents - ) - - # ensure the hidden state will require grad - if args.gradient_checkpointing: - for x in noisy_latents: - x.requires_grad_(True) - for t in text_encoder_conds: - t.requires_grad_(True) - - # Predict the noise residual - with accelerator.autocast(): - noise_pred = self.call_unet( - args, - accelerator, - unet, - noisy_latents.requires_grad_(train_unet), - timesteps, - text_encoder_conds, - batch, - weight_dtype, - ) - - if args.v_parameterization: - # v-parameterization training - target = noise_scheduler.get_velocity(latents, noise, timesteps) - else: - target = noise - - loss = train_util.conditional_loss( - noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c + on_step_start_for_network(text_encoder, unet) + + # preprocess batch for each model + self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype, is_train=True) + + loss = self.process_batch( + batch, + text_encoders, + unet, + network, + vae, + noise_scheduler, + vae_dtype, + weight_dtype, + accelerator, + args, + text_encoding_strategy, + tokenize_strategy, + is_train=True, + train_text_encoder=train_text_encoder, + train_unet=train_unet, ) - if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None): - loss = apply_masked_loss(loss, batch) - loss = loss.mean([1, 2, 3]) - - loss_weights = batch["loss_weights"] # 各sampleごとのweight - loss = loss * loss_weights - - if args.min_snr_gamma: - loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization) - if args.scale_v_pred_loss_like_noise_pred: - loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler) - if args.v_pred_like_loss: - loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss) - if args.debiased_estimation_loss: - loss = apply_debiased_estimation(loss, timesteps, noise_scheduler, args.v_parameterization) - - loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし accelerator.backward(loss) if accelerator.sync_gradients: @@ -1017,6 +1449,11 @@ def remove_model(old_ckpt_name): params_to_clip = accelerator.unwrap_model(network).get_trainable_params() accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) + if hasattr(network, "update_grad_norms"): + network.update_grad_norms() + if hasattr(network, "update_norms"): + network.update_norms() + optimizer.step() lr_scheduler.step() optimizer.zero_grad(set_to_none=True) @@ -1025,16 +1462,36 @@ def remove_model(old_ckpt_name): keys_scaled, mean_norm, maximum_norm = accelerator.unwrap_model(network).apply_max_norm_regularization( args.scale_weight_norms, accelerator.device ) + mean_grad_norm = None + mean_combined_norm = None max_mean_logs = {"Keys Scaled": keys_scaled, "Average key norm": mean_norm} else: - keys_scaled, mean_norm, maximum_norm = None, None, None + if hasattr(network, "weight_norms"): + weight_norms = network.weight_norms() + mean_norm = weight_norms.mean().item() if weight_norms is not None else None + grad_norms = network.grad_norms() + mean_grad_norm = grad_norms.mean().item() if grad_norms is not None else None + combined_weight_norms = network.combined_weight_norms() + mean_combined_norm = combined_weight_norms.mean().item() if combined_weight_norms is not None else None + maximum_norm = weight_norms.max().item() if weight_norms is not None else None + keys_scaled = None + max_mean_logs = {} + else: + keys_scaled, mean_norm, maximum_norm = None, None, None + mean_grad_norm = None + mean_combined_norm = None + max_mean_logs = {} # Checks if the accelerator has performed an optimization step behind the scenes if accelerator.sync_gradients: progress_bar.update(1) global_step += 1 - self.sample_images(accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet) + optimizer_eval_fn() + self.sample_images( + accelerator, args, None, global_step, accelerator.device, vae, tokenizers, text_encoder, unet + ) + progress_bar.unpause() # 指定ステップごとにモデルを保存 if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0: @@ -1050,32 +1507,189 @@ def remove_model(old_ckpt_name): if remove_step_no is not None: remove_ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, remove_step_no) remove_model(remove_ckpt_name) + optimizer_train_fn() current_loss = loss.detach().item() loss_recorder.add(epoch=epoch, step=step, loss=current_loss) avr_loss: float = loss_recorder.moving_average logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]} - progress_bar.set_postfix(**logs) - - if args.scale_weight_norms: - progress_bar.set_postfix(**{**max_mean_logs, **logs}) + progress_bar.set_postfix(**{**max_mean_logs, **logs}) - if args.logging_dir is not None: + if is_tracking: logs = self.generate_step_logs( - args, current_loss, avr_loss, lr_scheduler, lr_descriptions, keys_scaled, mean_norm, maximum_norm + args, + current_loss, + avr_loss, + lr_scheduler, + lr_descriptions, + optimizer, + keys_scaled, + mean_norm, + maximum_norm, + mean_grad_norm, + mean_combined_norm, ) - accelerator.log(logs, step=global_step) + self.step_logging(accelerator, logs, global_step, epoch + 1) + + # VALIDATION PER STEP: global_step is already incremented + # for example, if validate_every_n_steps=100, validate at step 100, 200, 300, ... + should_validate_step = args.validate_every_n_steps is not None and global_step % args.validate_every_n_steps == 0 + if accelerator.sync_gradients and validation_steps > 0 and should_validate_step: + optimizer_eval_fn() + accelerator.unwrap_model(network).eval() + rng_states = switch_rng_state(args.validation_seed if args.validation_seed is not None else args.seed) + + val_progress_bar = tqdm( + range(validation_total_steps), + smoothing=0, + disable=not accelerator.is_local_main_process, + desc="validation steps", + ) + val_timesteps_step = 0 + for val_step, batch in enumerate(val_dataloader): + if val_step >= validation_steps: + break + + for timestep in validation_timesteps: + self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype, is_train=False) + + args.min_timestep = args.max_timestep = timestep # dirty hack to change timestep + + loss = self.process_batch( + batch, + text_encoders, + unet, + network, + vae, + noise_scheduler, + vae_dtype, + weight_dtype, + accelerator, + args, + text_encoding_strategy, + tokenize_strategy, + is_train=False, + train_text_encoder=train_text_encoder, # this is needed for validation because Text Encoders must be called if train_text_encoder is True + train_unet=train_unet, + ) + + current_loss = loss.detach().item() + val_step_loss_recorder.add(epoch=epoch, step=val_timesteps_step, loss=current_loss) + val_progress_bar.update(1) + val_progress_bar.set_postfix( + {"val_avg_loss": val_step_loss_recorder.moving_average, "timestep": timestep} + ) + + # if is_tracking: + # logs = {f"loss/validation/step_current_{timestep}": current_loss} + # self.val_logging(accelerator, logs, global_step, epoch + 1, val_step) + + self.on_validation_step_end(args, accelerator, network, text_encoders, unet, batch, weight_dtype) + val_timesteps_step += 1 + + if is_tracking: + loss_validation_divergence = val_step_loss_recorder.moving_average - loss_recorder.moving_average + logs = { + "loss/validation/step_average": val_step_loss_recorder.moving_average, + "loss/validation/step_divergence": loss_validation_divergence, + } + self.step_logging(accelerator, logs, global_step, epoch=epoch + 1) + + restore_rng_state(rng_states) + args.min_timestep = original_args_min_timestep + args.max_timestep = original_args_max_timestep + optimizer_train_fn() + accelerator.unwrap_model(network).train() + progress_bar.unpause() if global_step >= args.max_train_steps: break - if args.logging_dir is not None: - logs = {"loss/epoch": loss_recorder.moving_average} - accelerator.log(logs, step=epoch + 1) + # EPOCH VALIDATION + should_validate_epoch = ( + (epoch + 1) % args.validate_every_n_epochs == 0 if args.validate_every_n_epochs is not None else True + ) + + if should_validate_epoch and len(val_dataloader) > 0: + optimizer_eval_fn() + accelerator.unwrap_model(network).eval() + rng_states = switch_rng_state(args.validation_seed if args.validation_seed is not None else args.seed) + + val_progress_bar = tqdm( + range(validation_total_steps), + smoothing=0, + disable=not accelerator.is_local_main_process, + desc="epoch validation steps", + ) + + val_timesteps_step = 0 + for val_step, batch in enumerate(val_dataloader): + if val_step >= validation_steps: + break + + for timestep in validation_timesteps: + args.min_timestep = args.max_timestep = timestep + + # temporary, for batch processing + self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype, is_train=False) + + loss = self.process_batch( + batch, + text_encoders, + unet, + network, + vae, + noise_scheduler, + vae_dtype, + weight_dtype, + accelerator, + args, + text_encoding_strategy, + tokenize_strategy, + is_train=False, + train_text_encoder=train_text_encoder, + train_unet=train_unet, + ) + + current_loss = loss.detach().item() + val_epoch_loss_recorder.add(epoch=epoch, step=val_timesteps_step, loss=current_loss) + val_progress_bar.update(1) + val_progress_bar.set_postfix( + {"val_epoch_avg_loss": val_epoch_loss_recorder.moving_average, "timestep": timestep} + ) + + # if is_tracking: + # logs = {f"loss/validation/epoch_current_{timestep}": current_loss} + # self.val_logging(accelerator, logs, global_step, epoch + 1, val_step) + + self.on_validation_step_end(args, accelerator, network, text_encoders, unet, batch, weight_dtype) + val_timesteps_step += 1 + + if is_tracking: + avr_loss: float = val_epoch_loss_recorder.moving_average + loss_validation_divergence = val_epoch_loss_recorder.moving_average - loss_recorder.moving_average + logs = { + "loss/validation/epoch_average": avr_loss, + "loss/validation/epoch_divergence": loss_validation_divergence, + } + self.epoch_logging(accelerator, logs, global_step, epoch + 1) + + restore_rng_state(rng_states) + args.min_timestep = original_args_min_timestep + args.max_timestep = original_args_max_timestep + optimizer_train_fn() + accelerator.unwrap_model(network).train() + progress_bar.unpause() + + # END OF EPOCH + if is_tracking: + logs = {"loss/epoch_average": loss_recorder.moving_average} + self.epoch_logging(accelerator, logs, global_step, epoch + 1) accelerator.wait_for_everyone() # 指定エポックごとにモデルを保存 + optimizer_eval_fn() if args.save_every_n_epochs is not None: saving = (epoch + 1) % args.save_every_n_epochs == 0 and (epoch + 1) < num_train_epochs if is_main_process and saving: @@ -1090,7 +1704,9 @@ def remove_model(old_ckpt_name): if args.save_state: train_util.save_and_remove_state_on_epoch_end(args, accelerator, epoch + 1) - self.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet) + self.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizers, text_encoder, unet) + progress_bar.unpause() + optimizer_train_fn() # end of epoch @@ -1101,6 +1717,7 @@ def remove_model(old_ckpt_name): network = accelerator.unwrap_model(network) accelerator.end_training() + optimizer_eval_fn() if is_main_process and (args.save_state or args.save_state_on_train_end): train_util.save_state_on_train_end(args, accelerator) @@ -1117,6 +1734,7 @@ def setup_parser() -> argparse.ArgumentParser: add_logging_arguments(parser) train_util.add_sd_models_arguments(parser) + sai_model_spec.add_model_spec_arguments(parser) train_util.add_dataset_arguments(parser, True, True, True) train_util.add_training_arguments(parser, True) train_util.add_masked_loss_arguments(parser) @@ -1125,6 +1743,12 @@ def setup_parser() -> argparse.ArgumentParser: config_util.add_config_arguments(parser) custom_train_functions.add_custom_train_arguments(parser) + parser.add_argument( + "--cpu_offload_checkpointing", + action="store_true", + help="[EXPERIMENTAL] enable offloading of tensors to CPU during checkpointing for U-Net or DiT, if supported" + " / 勾配チェックポイント時にテンソルをCPUにオフロードする(U-NetまたはDiTのみ、サポートされている場合)", + ) parser.add_argument( "--no_metadata", action="store_true", help="do not save metadata in output model / メタデータを出力先モデルに保存しない" ) @@ -1137,7 +1761,19 @@ def setup_parser() -> argparse.ArgumentParser: ) parser.add_argument("--unet_lr", type=float, default=None, help="learning rate for U-Net / U-Netの学習率") - parser.add_argument("--text_encoder_lr", type=float, default=None, help="learning rate for Text Encoder / Text Encoderの学習率") + parser.add_argument( + "--text_encoder_lr", + type=float, + default=None, + nargs="*", + help="learning rate for Text Encoder, can be multiple / Text Encoderの学習率、複数指定可能", + ) + parser.add_argument( + "--fp8_base_unet", + action="store_true", + help="use fp8 for U-Net (or DiT), Text Encoder is fp16 or bf16" + " / U-Net(またはDiT)にfp8を使用する。Text Encoderはfp16またはbf16", + ) parser.add_argument( "--network_weights", type=str, default=None, help="pretrained weights for network / 学習するネットワークの初期重み" @@ -1233,9 +1869,36 @@ def setup_parser() -> argparse.ArgumentParser: help="initial step number including all epochs, 0 means first step (same as not specifying). overwrites initial_epoch." + " / 初期ステップ数、全エポックを含むステップ数、0で最初のステップ(未指定時と同じ)。initial_epochを上書きする", ) - # parser.add_argument("--loraplus_lr_ratio", default=None, type=float, help="LoRA+ learning rate ratio") - # parser.add_argument("--loraplus_unet_lr_ratio", default=None, type=float, help="LoRA+ UNet learning rate ratio") - # parser.add_argument("--loraplus_text_encoder_lr_ratio", default=None, type=float, help="LoRA+ text encoder learning rate ratio") + parser.add_argument( + "--validation_seed", + type=int, + default=None, + help="Validation seed for shuffling validation dataset, training `--seed` used otherwise / 検証データセットをシャッフルするための検証シード、それ以外の場合はトレーニング `--seed` を使用する", + ) + parser.add_argument( + "--validation_split", + type=float, + default=0.0, + help="Split for validation images out of the training dataset / 学習画像から検証画像に分割する割合", + ) + parser.add_argument( + "--validate_every_n_steps", + type=int, + default=None, + help="Run validation on validation dataset every N steps. By default, validation will only occur every epoch if a validation dataset is available / 検証データセットの検証をNステップごとに実行します。デフォルトでは、検証データセットが利用可能な場合にのみ、検証はエポックごとに実行されます", + ) + parser.add_argument( + "--validate_every_n_epochs", + type=int, + default=None, + help="Run validation dataset every N epochs. By default, validation will run every epoch if a validation dataset is available / 検証データセットをNエポックごとに実行します。デフォルトでは、検証データセットが利用可能な場合、検証はエポックごとに実行されます", + ) + parser.add_argument( + "--max_validation_steps", + type=int, + default=None, + help="Max number of validation dataset items processed. By default, validation will run the entire validation dataset / 処理される検証データセット項目の最大数。デフォルトでは、検証は検証データセット全体を実行します", + ) return parser diff --git a/train_textual_inversion.py b/train_textual_inversion.py index 37349da7d..8575698d6 100644 --- a/train_textual_inversion.py +++ b/train_textual_inversion.py @@ -2,6 +2,7 @@ import math import os from multiprocessing import Value +from typing import Any, List, Optional, Union import toml from tqdm import tqdm @@ -15,7 +16,7 @@ from accelerate.utils import set_seed from diffusers import DDPMScheduler from transformers import CLIPTokenizer -from library import deepspeed_utils, model_util +from library import deepspeed_utils, model_util, strategy_base, strategy_sd, sai_model_spec import library.train_util as train_util import library.huggingface_util as huggingface_util @@ -98,33 +99,46 @@ def __init__(self): self.vae_scale_factor = 0.18215 self.is_sdxl = False - def assert_extra_args(self, args, train_dataset_group): + def assert_extra_args(self, args, train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset], val_dataset_group: Optional[train_util.DatasetGroup]): train_dataset_group.verify_bucket_reso_steps(64) + if val_dataset_group is not None: + val_dataset_group.verify_bucket_reso_steps(64) + def load_target_model(self, args, weight_dtype, accelerator): text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype, accelerator) - return model_util.get_model_version_str_for_sd1_sd2(args.v2, args.v_parameterization), text_encoder, vae, unet + return model_util.get_model_version_str_for_sd1_sd2(args.v2, args.v_parameterization), [text_encoder], vae, unet + + def get_tokenize_strategy(self, args): + return strategy_sd.SdTokenizeStrategy(args.v2, args.max_token_length, args.tokenizer_cache_dir) + + def get_tokenizers(self, tokenize_strategy: strategy_sd.SdTokenizeStrategy) -> List[Any]: + return [tokenize_strategy.tokenizer] - def load_tokenizer(self, args): - tokenizer = train_util.load_tokenizer(args) - return tokenizer + def get_latents_caching_strategy(self, args): + latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy( + True, args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check + ) + return latents_caching_strategy def assert_token_string(self, token_string, tokenizers: CLIPTokenizer): pass - def get_text_cond(self, args, accelerator, batch, tokenizers, text_encoders, weight_dtype): - with torch.enable_grad(): - input_ids = batch["input_ids"].to(accelerator.device) - encoder_hidden_states = train_util.get_hidden_states(args, input_ids, tokenizers[0], text_encoders[0], None) - return encoder_hidden_states + def get_text_encoding_strategy(self, args): + return strategy_sd.SdTextEncodingStrategy(args.clip_skip) + + def get_models_for_text_encoding(self, args, accelerator, text_encoders) -> List[Any]: + return text_encoders def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_conds, batch, weight_dtype): - noise_pred = unet(noisy_latents, timesteps, text_conds).sample + noise_pred = unet(noisy_latents, timesteps, text_conds[0]).sample return noise_pred - def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet, prompt_replacement): + def sample_images( + self, accelerator, args, epoch, global_step, device, vae, tokenizers, text_encoders, unet, prompt_replacement + ): train_util.sample_images( - accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet, prompt_replacement + accelerator, args, epoch, global_step, device, vae, tokenizers[0], text_encoders[0], unet, prompt_replacement ) def save_weights(self, file, updated_embs, save_dtype, metadata): @@ -182,8 +196,13 @@ def train(self, args): if args.seed is not None: set_seed(args.seed) - tokenizer_or_list = self.load_tokenizer(args) # list of tokenizer or tokenizer - tokenizers = tokenizer_or_list if isinstance(tokenizer_or_list, list) else [tokenizer_or_list] + tokenize_strategy = self.get_tokenize_strategy(args) + strategy_base.TokenizeStrategy.set_strategy(tokenize_strategy) + tokenizers = self.get_tokenizers(tokenize_strategy) # will be removed after sample_image is refactored + + # prepare caching strategy: this must be set before preparing dataset. because dataset may use this strategy for initialization. + latents_caching_strategy = self.get_latents_caching_strategy(args) + strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy) # acceleratorを準備する logger.info("prepare accelerator") @@ -194,14 +213,7 @@ def train(self, args): vae_dtype = torch.float32 if args.no_half_vae else weight_dtype # モデルを読み込む - model_version, text_encoder_or_list, vae, unet = self.load_target_model(args, weight_dtype, accelerator) - text_encoders = [text_encoder_or_list] if not isinstance(text_encoder_or_list, list) else text_encoder_or_list - - if len(text_encoders) > 1 and args.gradient_accumulation_steps > 1: - accelerator.print( - "accelerate doesn't seem to support gradient_accumulation_steps for multiple models (text encoders) / " - + "accelerateでは複数のモデル(テキストエンコーダー)のgradient_accumulation_stepsはサポートされていないようです" - ) + model_version, text_encoders, vae, unet = self.load_target_model(args, weight_dtype, accelerator) # Convert the init_word to token_id init_token_ids_list = [] @@ -310,12 +322,13 @@ def train(self, args): ] } - blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer_or_list) - train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) + blueprint = blueprint_generator.generate(user_config, args) + train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) else: - train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizer_or_list) + train_dataset_group = train_util.load_arbitrary_dataset(args) + val_dataset_group = None - self.assert_extra_args(args, train_dataset_group) + self.assert_extra_args(args, train_dataset_group, val_dataset_group) current_epoch = Value("i", 0) current_step = Value("i", 0) @@ -368,11 +381,10 @@ def train(self, args): vae.to(accelerator.device, dtype=vae_dtype) vae.requires_grad_(False) vae.eval() - with torch.no_grad(): - train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process) - vae.to("cpu") - clean_memory_on_device(accelerator.device) + train_dataset_group.new_cache_latents(vae, accelerator) + + clean_memory_on_device(accelerator.device) accelerator.wait_for_everyone() if args.gradient_checkpointing: @@ -387,7 +399,11 @@ def train(self, args): trainable_params += text_encoder.get_input_embeddings().parameters() _, _, optimizer = train_util.get_optimizer(args, trainable_params) - # dataloaderを準備する + # prepare dataloader + # strategies are set here because they cannot be referenced in another process. Copy them with the dataset + # some strategies can be None + train_dataset_group.set_current_strategies() + # DataLoaderのプロセス数:0 は persistent_workers が使えないので注意 n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers train_dataloader = torch.utils.data.DataLoader( @@ -415,20 +431,8 @@ def train(self, args): lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) # acceleratorがなんかよろしくやってくれるらしい - if len(text_encoders) == 1: - text_encoder_or_list, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( - text_encoder_or_list, optimizer, train_dataloader, lr_scheduler - ) - - elif len(text_encoders) == 2: - text_encoder1, text_encoder2, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( - text_encoders[0], text_encoders[1], optimizer, train_dataloader, lr_scheduler - ) - - text_encoder_or_list = text_encoders = [text_encoder1, text_encoder2] - - else: - raise NotImplementedError() + optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler) + text_encoders = [accelerator.prepare(text_encoder) for text_encoder in text_encoders] index_no_updates_list = [] orig_embeds_params_list = [] @@ -456,6 +460,9 @@ def train(self, args): else: unet.eval() + text_encoding_strategy = self.get_text_encoding_strategy(args) + strategy_base.TextEncodingStrategy.set_strategy(text_encoding_strategy) + if not cache_latents: # キャッシュしない場合はVAEを使うのでVAEを準備する vae.requires_grad_(False) vae.eval() @@ -510,7 +517,9 @@ def train(self, args): if args.log_tracker_config is not None: init_kwargs = toml.load(args.log_tracker_config) accelerator.init_trackers( - "textual_inversion" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.get_sanitized_config_or_none(args), init_kwargs=init_kwargs + "textual_inversion" if args.log_tracker_name is None else args.log_tracker_name, + config=train_util.get_sanitized_config_or_none(args), + init_kwargs=init_kwargs, ) # function for saving/removing @@ -540,11 +549,14 @@ def remove_model(old_ckpt_name): global_step, accelerator.device, vae, - tokenizer_or_list, - text_encoder_or_list, + tokenizers, + text_encoders, unet, prompt_replacement, ) + if len(accelerator.trackers) > 0: + # log empty object to commit the sample images to wandb + accelerator.log({}, step=0) # training loop for epoch in range(num_train_epochs): @@ -568,11 +580,16 @@ def remove_model(old_ckpt_name): latents = latents * self.vae_scale_factor # Get the text embedding for conditioning - text_encoder_conds = self.get_text_cond(args, accelerator, batch, tokenizers, text_encoders, weight_dtype) + input_ids = [ids.to(accelerator.device) for ids in batch["input_ids_list"]] + text_encoder_conds = text_encoding_strategy.encode_tokens( + tokenize_strategy, self.get_models_for_text_encoding(args, accelerator, text_encoders), input_ids + ) + if args.full_fp16: + text_encoder_conds = [c.to(weight_dtype) for c in text_encoder_conds] # Sample noise, sample a random timestep for each image, and add noise to the latents, # with noise offset and/or multires noise if specified - noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps( + noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps( args, noise_scheduler, latents ) @@ -588,7 +605,8 @@ def remove_model(old_ckpt_name): else: target = noise - loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c) + huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler) + loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c) if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None): loss = apply_masked_loss(loss, batch) loss = loss.mean([1, 2, 3]) @@ -639,8 +657,8 @@ def remove_model(old_ckpt_name): global_step, accelerator.device, vae, - tokenizer_or_list, - text_encoder_or_list, + tokenizers, + text_encoders, unet, prompt_replacement, ) @@ -672,7 +690,7 @@ def remove_model(old_ckpt_name): remove_model(remove_ckpt_name) current_loss = loss.detach().item() - if args.logging_dir is not None: + if len(accelerator.trackers) > 0: logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])} if ( args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower() @@ -690,7 +708,7 @@ def remove_model(old_ckpt_name): if global_step >= args.max_train_steps: break - if args.logging_dir is not None: + if len(accelerator.trackers) > 0: logs = {"loss/epoch": loss_total / len(train_dataloader)} accelerator.log(logs, step=epoch + 1) @@ -722,11 +740,12 @@ def remove_model(old_ckpt_name): global_step, accelerator.device, vae, - tokenizer_or_list, - text_encoder_or_list, + tokenizers, + text_encoders, unet, prompt_replacement, ) + accelerator.log({}) # end of epoch @@ -752,6 +771,7 @@ def setup_parser() -> argparse.ArgumentParser: add_logging_arguments(parser) train_util.add_sd_models_arguments(parser) + sai_model_spec.add_model_spec_arguments(parser) train_util.add_dataset_arguments(parser, True, True, False) train_util.add_training_arguments(parser, True) train_util.add_masked_loss_arguments(parser) diff --git a/train_textual_inversion_XTI.py b/train_textual_inversion_XTI.py index fac0787b9..778210950 100644 --- a/train_textual_inversion_XTI.py +++ b/train_textual_inversion_XTI.py @@ -21,6 +21,7 @@ import library.train_util as train_util import library.huggingface_util as huggingface_util import library.config_util as config_util +import library.sai_model_spec as sai_model_spec from library.config_util import ( ConfigSanitizer, BlueprintGenerator, @@ -239,7 +240,7 @@ def train(args): } blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer) - train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) + train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) train_dataset_group.enable_XTI(XTI_layers, token_strings=token_strings) current_epoch = Value("i", 0) current_step = Value("i", 0) @@ -407,7 +408,9 @@ def train(args): if args.log_tracker_config is not None: init_kwargs = toml.load(args.log_tracker_config) accelerator.init_trackers( - "textual_inversion" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.get_sanitized_config_or_none(args), init_kwargs=init_kwargs + "textual_inversion" if args.log_tracker_name is None else args.log_tracker_name, + config=train_util.get_sanitized_config_or_none(args), + init_kwargs=init_kwargs, ) # function for saving/removing @@ -461,7 +464,7 @@ def remove_model(old_ckpt_name): # Sample noise, sample a random timestep for each image, and add noise to the latents, # with noise offset and/or multires noise if specified - noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents) + noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents) # Predict the noise residual with accelerator.autocast(): @@ -473,7 +476,8 @@ def remove_model(old_ckpt_name): else: target = noise - loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c) + huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler) + loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c) if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None): loss = apply_masked_loss(loss, batch) loss = loss.mean([1, 2, 3]) @@ -538,7 +542,7 @@ def remove_model(old_ckpt_name): remove_model(remove_ckpt_name) current_loss = loss.detach().item() - if args.logging_dir is not None: + if len(accelerator.trackers) > 0: logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])} if ( args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower() @@ -556,7 +560,7 @@ def remove_model(old_ckpt_name): if global_step >= args.max_train_steps: break - if args.logging_dir is not None: + if len(accelerator.trackers) > 0: logs = {"loss/epoch": loss_total / len(train_dataloader)} accelerator.log(logs, step=epoch + 1) @@ -665,6 +669,7 @@ def setup_parser() -> argparse.ArgumentParser: add_logging_arguments(parser) train_util.add_sd_models_arguments(parser) + sai_model_spec.add_model_spec_arguments(parser) train_util.add_dataset_arguments(parser, True, True, False) train_util.add_training_arguments(parser, True) train_util.add_masked_loss_arguments(parser)