Description
Describe the bug
I tried to run the FLUX.1-Canny-dev-lora example code from https://huggingface.co/docs/diffusers/v0.33.1/en/api/pipelines/flux#canny-control, but got the following error:
RuntimeError: Error(s) in loading state_dict for FluxTransformer2DModel:
size mismatch for proj_out.lora_A.default_0.weight: copying a param with shape torch.Size([64, 3072]) from checkpoint, the shape in current model is torch.Size([128, 3072]).
size mismatch for proj_out.lora_B.default_0.weight: copying a param with shape torch.Size([64, 64]) from checkpoint, the shape in current model is torch.Size([64, 128]).
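The sketch below decodes the two shape pairs in the error. It assumes PEFT's usual layout, where `lora_A.weight` is `[rank, in_features]` and `lora_B.weight` is `[out_features, rank]`. The regex assumes the exact message format quoted above. `decode` and `describe` are hypothetical helpers, not part of diffusers or PEFT.

```python
import re

# Assumed PEFT layout: lora_A.weight is [rank, in_features],
# lora_B.weight is [out_features, rank].
ERROR_LINES = [
    "size mismatch for proj_out.lora_A.default_0.weight: copying a param with shape "
    "torch.Size([64, 3072]) from checkpoint, the shape in current model is torch.Size([128, 3072]).",
    "size mismatch for proj_out.lora_B.default_0.weight: copying a param with shape "
    "torch.Size([64, 64]) from checkpoint, the shape in current model is torch.Size([64, 128]).",
]

PATTERN = re.compile(
    r"size mismatch for (?P<name>[^:\s]+): copying a param with shape "
    r"torch\.Size\(\[(?P<ckpt>[\d, ]+)\]\) from checkpoint, "
    r"the shape in current model is torch\.Size\(\[(?P<model>[\d, ]+)\]\)"
)


def parse_shape(text):
    """Turn '64, 3072' into (64, 3072)."""
    return tuple(int(x) for x in text.split(","))


def describe(shape, kind):
    """Attach dimension names to a LoRA weight shape (hypothetical helper)."""
    if kind == "lora_A":
        rank, in_features = shape
        return {"rank": rank, "in_features": in_features}
    out_features, rank = shape
    return {"rank": rank, "out_features": out_features}


def decode(line):
    """Return (param name, checkpoint dims, model dims, names of differing dims)."""
    match = PATTERN.search(line)
    if match is None:
        raise ValueError("unrecognized size-mismatch line")
    name = match.group("name")
    kind = "lora_A" if ".lora_A." in name else "lora_B"
    ckpt = describe(parse_shape(match.group("ckpt")), kind)
    model = describe(parse_shape(match.group("model")), kind)
    differing = sorted(k for k in ckpt if ckpt[k] != model[k])
    return name, ckpt, model, differing


for line in ERROR_LINES:
    name, ckpt, model, differing = decode(line)
    print(f"{name}: checkpoint={ckpt} model={model} differs in {differing}")
```

Under this assumed layout, the feature sizes (3072 in, 64 out) agree in both lines. The only dimension that differs is the rank.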
I checked the code inside the pipeline. It concatenates the noisy tokens and the condition tokens along the channel dimension, which changes the input shape from [batch, token_length, 64] to [batch, token_length, 128].
(https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/flux/pipeline_flux_control.py#L830-L831)
Therefore, the LoRA parameters for the first layer do not match the shapes of the base Flux model.
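The channel concatenation described above can be reproduced with plain arrays. This is a minimal NumPy sketch, not the pipeline's code. The token count is arbitrary, and the 3072-wide projections are hypothetical stand-ins for an input layer that expects 64 channels and one widened to 128.

```python
import numpy as np

# Illustrative sizes: 64 channels per token as described above; token_length is arbitrary.
batch, token_length, channels = 1, 16, 64
hidden_size = 3072  # width that appears in the error message

rng = np.random.default_rng(0)
noisy_tokens = rng.standard_normal((batch, token_length, channels), dtype=np.float32)
control_tokens = rng.standard_normal((batch, token_length, channels), dtype=np.float32)

# Concatenate along the channel (last) dimension: 64 + 64 -> 128.
model_input = np.concatenate([noisy_tokens, control_tokens], axis=-1)
print(model_input.shape)

# Hypothetical input projection sized for 64 input channels, stored as [out, in].
base_weight = rng.standard_normal((hidden_size, channels), dtype=np.float32)
try:
    model_input @ base_weight.T
except ValueError as err:
    print("64-channel projection rejects the 128-channel input:", err)

# Hypothetical projection widened to 128 input channels accepts the concatenated tokens.
wide_weight = rng.standard_normal((hidden_size, 2 * channels), dtype=np.float32)
projected = model_input @ wide_weight.T
print(projected.shape)
```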
Reproduction
```python
# !pip install -U controlnet-aux
import torch
from controlnet_aux import CannyDetector
from diffusers import FluxControlPipeline
from diffusers.utils import load_image

pipe = FluxControlPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16).to("cuda")
pipe.load_lora_weights("black-forest-labs/FLUX.1-Canny-dev-lora")

prompt = "A robot made of exotic candies and chocolates of different kinds. The background is filled with confetti and celebratory gifts."
control_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/robot.png")

processor = CannyDetector()
control_image = processor(control_image, low_threshold=50, high_threshold=200, detect_resolution=1024, image_resolution=1024)

image = pipe(
    prompt=prompt,
    control_image=control_image,
    height=1024,
    width=1024,
    num_inference_steps=50,
    guidance_scale=30.0,
).images[0]
image.save("output.png")
```
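When comparing the checkpoint against the model, it can help to list the LoRA rank of every module. The helper below is a hypothetical sketch. It assumes keys end in `.lora_A.weight` (checkpoint style) or `.lora_A.<adapter>.weight` (the model-side style seen in the error), and that dimension 0 of that weight is the rank. The `layer_a` shapes are made up for illustration; only the `proj_out` shapes are copied from the error message.

```python
import re
from collections import Counter

# Matches "<module>.lora_A.weight" and "<module>.lora_A.<adapter>.weight".
LORA_A_KEY = re.compile(r"^(?P<module>.+)\.lora_A(?:\.[^.]+)?\.weight$")


def lora_ranks(shapes):
    """Map each module name to its LoRA rank, read from dim 0 of lora_A.weight."""
    ranks = {}
    for name, shape in shapes.items():
        match = LORA_A_KEY.match(name)
        if match:
            ranks[match.group("module")] = shape[0]
    return ranks


# Hypothetical shapes; only the proj_out entries come from the error message.
example_shapes = {
    "layer_a.lora_A.weight": (128, 3072),
    "layer_a.lora_B.weight": (3072, 128),
    "proj_out.lora_A.weight": (64, 3072),
    "proj_out.lora_B.weight": (64, 64),
}

ranks = lora_ranks(example_shapes)
print(ranks)                    # rank per module
print(Counter(ranks.values()))  # number of modules using each rank
```

For the real checkpoint, the shapes could presumably be gathered with `{k: tuple(v.shape) for k, v in FluxControlPipeline.lora_state_dict("black-forest-labs/FLUX.1-Canny-dev-lora").items()}`. This assumes `lora_state_dict` returns a plain dictionary of tensors when called without `return_alphas`.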
Logs
System Info
- 🤗 Diffusers version: 0.33.1
- Platform: Linux-5.10.134-13.an8.x86_64-x86_64-with-glibc2.31
- Running on Google Colab?: No
- Python version: 3.10.16
- PyTorch version (GPU?): 2.6.0+cu124 (True)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Huggingface_hub version: 0.29.2
- Transformers version: 4.43.3
- Accelerate version: 0.30.1
- PEFT version: 0.14.0
- Bitsandbytes version: not installed
- Safetensors version: 0.5.3
- xFormers version: not installed
- Accelerator: NVIDIA L40S, 46068 MiB
- Using GPU in script?: Yes
- Using distributed or parallel set-up in script?: No