Description
Describe the bug
When using the SD3 ControlNet training script, the training dataset embeddings are precomputed and cached under a fingerprint derived from the full set of script arguments. Subsequent runs reuse the cached preprocessed dataset instead of recomputing the embeddings, which in my experience takes a while. However, because arguments completely unrelated to the dataset also affect this hash, any minor change to a runtime argument triggers a full dataset remap.
The code in question:
diffusers/examples/controlnet/train_controlnet_sd3.py
Lines 1167 to 1183 in ed4efbd
Ideally, the hash should only depend on arguments directly related to the dataset and the embeddings computation, namely --dataset_name, --dataset_config_name, --pretrained_model_name_or_path, --variant, and --revision (I could be wrong, and other arguments may also affect the embeddings; someone please validate).
This issue might affect other example scripts. I haven't looked closely at the others, but if the same technique is used elsewhere it could be causing similar training startup delays, especially for people reusing the same dataset across multiple training attempts.
Reproduction
Step 1: Start training an SD3 ControlNet model (specific dataset shouldn't matter, and you can exit the script after the map() is complete and training begins)
python3 examples/controlnet/train_controlnet_sd3.py --pretrained_model_name_or_path=/path/to/your/stable-diffusion-3-medium-diffusers --output_dir=/path/to/output --dataset_name=fusing/fill50k --resolution=1024 --learning_rate=1e-5 --train_batch_size=2 --dataset_preprocess_batch_size=500
Step 2: relaunch with the same arguments. Observe that the map() is skipped and training restarts fairly quickly.
Step 3: relaunch with --train_batch_size reduced to 1. Observe that the map() is restarted.
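To make the symptom concrete, here is a minimal, self-contained sketch of the failure mode. It uses a `hashlib`-based stand-in for `datasets.fingerprint.Hasher.hash` (which the real script calls) so the example runs without a training setup; the argument names mirror the script's.

```python
# Stand-in illustrating the caching problem: a fingerprint computed over the
# full argparse.Namespace changes whenever ANY argument changes, even ones
# unrelated to the dataset. hashlib is used here as a self-contained stand-in
# for datasets.fingerprint.Hasher.hash.
import argparse
import hashlib


def fingerprint(ns: argparse.Namespace) -> str:
    # Deterministic hash over all attributes of the namespace
    return hashlib.sha256(repr(sorted(vars(ns).items())).encode()).hexdigest()


run1 = argparse.Namespace(dataset_name="fusing/fill50k", train_batch_size=2)
run2 = argparse.Namespace(dataset_name="fusing/fill50k", train_batch_size=1)

# Same dataset, different batch size -> different fingerprint -> cache miss
print(fingerprint(run1) == fingerprint(run2))  # False
```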
Proposed fix
For the SD3 ControlNet script: either deep-copy the arguments and delete those that do not affect the dataset or the embeddings, or create a new argparse.Namespace() and copy over only the arguments that matter.
Option 1 would look something like:
import copy
args_copy = copy.deepcopy(args)
for unwanted_arg in ['output_dir', 'train_batch_size', 'dataset_preprocess_batch_size',
                     'gradient_accumulation_steps', 'gradient_checkpointing', 'learning_rate',
                     'max_train_steps', 'checkpointing_steps', 'lr_num_cycles',
                     'validation_prompt', 'validation_image', 'validation_steps']:
    if hasattr(args_copy, unwanted_arg):
        delattr(args_copy, unwanted_arg)
new_fingerprint = Hasher.hash(args_copy)
Option 2 would look something like:
args_copy = argparse.Namespace()
for dataset_arg in ['dataset_name', 'dataset_config_name', 'pretrained_model_name_or_path',
                    'variant', 'revision']:
    setattr(args_copy, dataset_arg, getattr(args, dataset_arg))
new_fingerprint = Hasher.hash(args_copy)
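A quick runnable check of the Option 2 approach, again using a `hashlib` stand-in for `Hasher.hash` so it works outside the script; the dataset-argument list matches the proposal above:

```python
# Option 2 in isolation: fingerprint only the dataset-relevant arguments, so
# changing an unrelated argument (train_batch_size) keeps the cache valid.
# hashlib stands in for datasets.fingerprint.Hasher.hash here.
import argparse
import hashlib


def fingerprint(ns: argparse.Namespace) -> str:
    return hashlib.sha256(repr(sorted(vars(ns).items())).encode()).hexdigest()


def dataset_fingerprint(args: argparse.Namespace) -> str:
    # Copy only the arguments that influence the cached dataset
    args_copy = argparse.Namespace()
    for dataset_arg in ['dataset_name', 'dataset_config_name',
                        'pretrained_model_name_or_path', 'variant', 'revision']:
        setattr(args_copy, dataset_arg, getattr(args, dataset_arg))
    return fingerprint(args_copy)


common = dict(dataset_name="fusing/fill50k", dataset_config_name=None,
              pretrained_model_name_or_path="sd3-medium", variant=None, revision=None)
run1 = argparse.Namespace(**common, train_batch_size=2)
run2 = argparse.Namespace(**common, train_batch_size=1)

# Unrelated argument changed, fingerprint is stable -> cache hit
print(dataset_fingerprint(run1) == dataset_fingerprint(run2))  # True
```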
Ideally, if the dataset is loaded from a local config, the contents of the config should be hashed instead of the filename itself, but fixing that would be a bit more complex (and might be beyond the scope of the example script).
System Info
- 🤗 Diffusers version: 0.34.0.dev0
- Platform: Linux-5.15.0-91-generic-x86_64-with-glibc2.35
- Running on Google Colab?: No
- Python version: 3.11.11
- PyTorch version (GPU?): 2.6.0+cu124 (True)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Huggingface_hub version: 0.29.3
- Transformers version: 4.50.0
- Accelerate version: 1.5.2
- PEFT version: 0.11.1
- Bitsandbytes version: not installed
- Safetensors version: 0.5.3
- xFormers version: not installed
- Accelerator: NVIDIA RTX A6000, 49140 MiB
- Using GPU in script?: Yes
- Using distributed or parallel set-up in script?: Tested with both single instance and accelerate with 2+ GPUs