
SD3 ControlNet Script (and others?): dataset preprocessing cache depends on unrelated arguments #11497

Open

@kentdan3msu

Description
Describe the bug

When using the SD3 ControlNet training script, the training dataset embeddings are precomputed and the result is cached under a fingerprint derived from the full set of input script arguments. Subsequent runs with the same fingerprint reuse the cached preprocessed dataset instead of recomputing the embeddings, which in my experience takes a while. However, arguments that are completely unrelated to the dataset also affect this hash, so any minor change to runtime arguments triggers a full dataset remap.

The code in question:

compute_embeddings_fn = functools.partial(
    compute_text_embeddings,
    text_encoders=text_encoders,
    tokenizers=tokenizers,
)
with accelerator.main_process_first():
    from datasets.fingerprint import Hasher

    # fingerprint used by the cache for the other processes to load the result
    # details: https://github.com/huggingface/diffusers/pull/4038#discussion_r1266078401
    new_fingerprint = Hasher.hash(args)
    train_dataset = train_dataset.map(
        compute_embeddings_fn,
        batched=True,
        batch_size=args.dataset_preprocess_batch_size,
        new_fingerprint=new_fingerprint,
    )

Ideally, the hash should only depend on arguments directly related to the dataset and the embeddings computation function, namely --dataset_name, --dataset_config_name, --pretrained_model_name_or_path, --variant, and --revision (I could be wrong and other arguments may or may not affect the embeddings, someone please validate).
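The core problem can be illustrated without diffusers or datasets installed. The sketch below uses a hypothetical stand-in fingerprint function (hashlib over the full namespace contents) in place of datasets' Hasher.hash, which hashes the args object the same way, attribute by attribute:

```python
import argparse
import hashlib


def fingerprint(ns: argparse.Namespace) -> str:
    """Stand-in for datasets.fingerprint.Hasher.hash: hashes EVERY attribute."""
    payload = repr(sorted(vars(ns).items())).encode()
    return hashlib.sha256(payload).hexdigest()


# Two runs over the SAME dataset, differing only in an unrelated runtime knob.
run1 = argparse.Namespace(dataset_name="fusing/fill50k", train_batch_size=2)
run2 = argparse.Namespace(dataset_name="fusing/fill50k", train_batch_size=1)

# The fingerprints differ, so the cached preprocessed dataset is not reused.
print(fingerprint(run1) == fingerprint(run2))  # False
```

Since every attribute participates in the hash, changing --train_batch_size (or any other unrelated flag) produces a new fingerprint and forces the map() to rerun.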

This issue might affect other example scripts - I haven't looked too deeply at other scripts, but if this technique is used elsewhere, it could be causing similar training startup delays, especially for people using the same dataset over multiple training attempts.

Reproduction

Step 1: Start training an SD3 ControlNet model (specific dataset shouldn't matter, and you can exit the script after the map() is complete and training begins)

python3 examples/controlnet/train_controlnet_sd3.py --pretrained_model_name_or_path=/path/to/your/stable-diffusion-3-medium-diffusers --output_dir=/path/to/output --dataset_name=fusing/fill50k --resolution=1024 --learning_rate=1e-5 --train_batch_size=2 --dataset_preprocess_batch_size=500

Step 2: relaunch with the same arguments. Observe that the map() is skipped and training restarts fairly quickly.

Step 3: relaunch with --train_batch_size reduced to 1. Observe that the map() is restarted.

Proposed fix

For the SD3 ControlNet script: either create a copy of the arguments and remove the arguments that do not affect the dataset preprocessing, or create a new argparse.Namespace() and copy in only the arguments that matter.

Option 1 would look something like:

import copy
args_copy = copy.deepcopy(args)
for unwanted_arg in ['output_dir', 'train_batch_size', 'dataset_preprocess_batch_size',
                     'gradient_accumulation_steps', 'gradient_checkpointing', 'learning_rate',
                     'max_train_steps', 'checkpointing_steps', 'lr_num_cycles',
                     'validation_prompt', 'validation_image', 'validation_steps']:
    if hasattr(args_copy, unwanted_arg):
        delattr(args_copy, unwanted_arg)
new_fingerprint = Hasher.hash(args_copy)

Option 2 would look something like:

import argparse

args_copy = argparse.Namespace()
for dataset_arg in ['dataset_name', 'dataset_config_name', 'pretrained_model_name_or_path', 'variant', 'revision']:
    setattr(args_copy, dataset_arg, getattr(args, dataset_arg))
new_fingerprint = Hasher.hash(args_copy)

Ideally, if the dataset is loaded from a local config, the contents of the config should be hashed instead of the filename itself, but fixing that would be a bit more complex (and might be beyond the scope of the example script).
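Hashing the config contents instead of the filename could be sketched like this (a minimal, hypothetical helper; a real fix would also need to handle remote datasets and missing files):

```python
import hashlib
from pathlib import Path


def file_fingerprint(path: str) -> str:
    """Hash the contents of a local config file, so that edits to the file
    (not just a change of filename) invalidate the preprocessing cache."""
    return hashlib.sha256(Path(path).read_bytes()).hexdigest()
```

The resulting digest could be mixed into the arguments namespace before it is hashed, so the cache is invalidated whenever the config file is edited in place.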

System Info

  • 🤗 Diffusers version: 0.34.0.dev0
  • Platform: Linux-5.15.0-91-generic-x86_64-with-glibc2.35
  • Running on Google Colab?: No
  • Python version: 3.11.11
  • PyTorch version (GPU?): 2.6.0+cu124 (True)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Huggingface_hub version: 0.29.3
  • Transformers version: 4.50.0
  • Accelerate version: 1.5.2
  • PEFT version: 0.11.1
  • Bitsandbytes version: not installed
  • Safetensors version: 0.5.3
  • xFormers version: not installed
  • Accelerator: NVIDIA RTX A6000, 49140 MiB
  • Using GPU in script?: Yes
  • Using distributed or parallel set-up in script?: Tested with both single instance and accelerate with 2+ GPUs

Who can help?

@sayakpaul
