Skip to content

Commit 82e0f5b

Browse files
authored
Merge branch 'main' into docker-build-workflow
2 parents fb08f3c + b9e1c30 commit 82e0f5b

33 files changed

+690
-154
lines changed

.github/workflows/push_tests.yml

+1-2
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ env:
2121
jobs:
2222
setup_torch_cuda_pipeline_matrix:
2323
name: Setup Torch Pipelines CUDA Slow Tests Matrix
24-
runs-on: docker-gpu
24+
runs-on: [single-gpu, nvidia-gpu, t4, ci]
2525
container:
2626
image: diffusers/diffusers-pytorch-cpu # this is a CPU image, but we need it to fetch the matrix
2727
options: --shm-size "16gb" --ipc host
@@ -62,7 +62,6 @@ jobs:
6262
needs: setup_torch_cuda_pipeline_matrix
6363
strategy:
6464
fail-fast: false
65-
max-parallel: 1
6665
matrix:
6766
module: ${{ fromJson(needs.setup_torch_cuda_pipeline_matrix.outputs.pipeline_test_matrix) }}
6867
runs-on: [single-gpu, nvidia-gpu, t4, ci]

benchmarks/base_classes.py

+29
Original file line numberDiff line numberDiff line change
@@ -236,6 +236,35 @@ def run_inference(self, pipe, args):
236236
)
237237

238238

239+
class IPAdapterTextToImageBenchmark(TextToImageBenchmark):
240+
url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/load_neg_embed.png"
241+
image = load_image(url)
242+
243+
def __init__(self, args):
244+
pipe = self.pipeline_class.from_pretrained(args.ckpt, torch_dtype=torch.float16).to("cuda")
245+
pipe.load_ip_adapter(
246+
args.ip_adapter_id[0],
247+
subfolder="models" if "sdxl" not in args.ip_adapter_id[1] else "sdxl_models",
248+
weight_name=args.ip_adapter_id[1],
249+
)
250+
251+
if args.run_compile:
252+
pipe.unet.to(memory_format=torch.channels_last)
253+
print("Run torch compile")
254+
pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
255+
256+
pipe.set_progress_bar_config(disable=True)
257+
self.pipe = pipe
258+
259+
def run_inference(self, pipe, args):
260+
_ = pipe(
261+
prompt=PROMPT,
262+
ip_adapter_image=self.image,
263+
num_inference_steps=args.num_inference_steps,
264+
num_images_per_prompt=args.batch_size,
265+
)
266+
267+
239268
class ControlNetBenchmark(TextToImageBenchmark):
240269
pipeline_class = StableDiffusionControlNetPipeline
241270
aux_network_class = ControlNetModel

benchmarks/benchmark_ip_adapters.py

+32
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
import argparse
2+
import sys
3+
4+
5+
sys.path.append(".")
6+
from base_classes import IPAdapterTextToImageBenchmark # noqa: E402
7+
8+
9+
IP_ADAPTER_CKPTS = {
10+
"runwayml/stable-diffusion-v1-5": ("h94/IP-Adapter", "ip-adapter_sd15.bin"),
11+
"stabilityai/stable-diffusion-xl-base-1.0": ("h94/IP-Adapter", "ip-adapter_sdxl.bin"),
12+
}
13+
14+
15+
if __name__ == "__main__":
16+
parser = argparse.ArgumentParser()
17+
parser.add_argument(
18+
"--ckpt",
19+
type=str,
20+
default="runwayml/stable-diffusion-v1-5",
21+
choices=list(IP_ADAPTER_CKPTS.keys()),
22+
)
23+
parser.add_argument("--batch_size", type=int, default=1)
24+
parser.add_argument("--num_inference_steps", type=int, default=50)
25+
parser.add_argument("--model_cpu_offload", action="store_true")
26+
parser.add_argument("--run_compile", action="store_true")
27+
args = parser.parse_args()
28+
29+
args.ip_adapter_id = IP_ADAPTER_CKPTS[args.ckpt]
30+
benchmark_pipe = IPAdapterTextToImageBenchmark(args)
31+
args.ckpt = f"{args.ckpt} (IP-Adapter)"
32+
benchmark_pipe.benchmark(args)

benchmarks/run_all.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ def main():
7272
command += " --run_compile"
7373
run_command(command.split())
7474

75-
elif file == "benchmark_sd_inpainting.py":
75+
elif file in ["benchmark_sd_inpainting.py", "benchmark_ip_adapters.py"]:
7676
sdxl_ckpt = "stabilityai/stable-diffusion-xl-base-1.0"
7777
command = f"python {file} --ckpt {sdxl_ckpt}"
7878
run_command(command.split())

docs/source/en/tutorials/using_peft_for_inference.md

+6-2
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,7 @@ list_adapters_component_wise
169169

170170
If you want to compile your model with `torch.compile`, make sure to first fuse the LoRA weights into the base model and unload them.
171171

172-
```py
172+
```diff
173173
pipe.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
174174
pipe.load_lora_weights("CiroN2022/toy-face", weight_name="toy_face_sdxl.safetensors", adapter_name="toy")
175175

@@ -178,12 +178,16 @@ pipe.set_adapters(["pixel", "toy"], adapter_weights=[0.5, 1.0])
178178
pipe.fuse_lora()
179179
pipe.unload_lora_weights()
180180

181-
pipe = torch.compile(pipe)
181+
+ pipe.unet.to(memory_format=torch.channels_last)
182+
+ pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
182183

183184
prompt = "toy_face of a hacker with a hoodie, pixel art"
184185
image = pipe(prompt, num_inference_steps=30, generator=torch.manual_seed(0)).images[0]
185186
```
186187

188+
> [!TIP]
189+
> You can refer to the `torch.compile()` section [here](https://huggingface.co/docs/diffusers/main/en/optimization/torch2.0#torchcompile) and [here](https://huggingface.co/docs/diffusers/main/en/tutorials/fast_diffusion#torchcompile) for more elaborate examples.
190+
187191
## Fusing adapters into the model
188192

189193
You can use PEFT to easily fuse/unfuse multiple adapters directly into the model weights (both UNet and text encoder) using the [`~diffusers.loaders.LoraLoaderMixin.fuse_lora`] method, which can lead to a speed-up in inference and lower VRAM usage.

examples/dreambooth/README_sdxl.md

+37
Original file line numberDiff line numberDiff line change
@@ -206,3 +206,40 @@ You can explore the results from a couple of our internal experiments by checkin
206206
## Running on a free-tier Colab Notebook
207207

208208
Check out [this notebook](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/SDXL_DreamBooth_LoRA_.ipynb).
209+
210+
## Conducting EDM-style training
211+
212+
It's now possible to perform EDM-style training as proposed in [Elucidating the Design Space of Diffusion-Based Generative Models](https://arxiv.org/abs/2206.00364).
213+
214+
For the SDXL model, simply set:
215+
216+
```diff
217+
+ --do_edm_style_training \
218+
```
219+
220+
Other SDXL-like models that use the EDM formulation, such as [playgroundai/playground-v2.5-1024px-aesthetic](https://huggingface.co/playgroundai/playground-v2.5-1024px-aesthetic), can also be DreamBooth'd with the script. Below is an example command:
221+
222+
```bash
223+
accelerate launch train_dreambooth_lora_sdxl.py \
224+
--pretrained_model_name_or_path="playgroundai/playground-v2.5-1024px-aesthetic" \
225+
--instance_data_dir="dog" \
226+
--output_dir="dog-playground-lora" \
227+
--mixed_precision="fp16" \
228+
--instance_prompt="a photo of sks dog" \
229+
--resolution=1024 \
230+
--train_batch_size=1 \
231+
--gradient_accumulation_steps=4 \
232+
--learning_rate=1e-4 \
233+
--use_8bit_adam \
234+
--report_to="wandb" \
235+
--lr_scheduler="constant" \
236+
--lr_warmup_steps=0 \
237+
--max_train_steps=500 \
238+
--validation_prompt="A photo of sks dog in a bucket" \
239+
--validation_epochs=25 \
240+
--seed="0" \
241+
--push_to_hub
242+
```
243+
244+
> [!CAUTION]
245+
> Min-SNR gamma is not yet supported with EDM-style training. When training with the PlaygroundAI model, it's recommended not to pass any "variant".
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
# coding=utf-8
2+
# Copyright 2024 HuggingFace Inc.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
import logging
17+
import os
18+
import sys
19+
import tempfile
20+
21+
import safetensors
22+
23+
24+
sys.path.append("..")
25+
from test_examples_utils import ExamplesTestsAccelerate, run_command # noqa: E402
26+
27+
28+
logging.basicConfig(level=logging.DEBUG)
29+
30+
logger = logging.getLogger()
31+
stream_handler = logging.StreamHandler(sys.stdout)
32+
logger.addHandler(stream_handler)
33+
34+
35+
class DreamBoothLoRASDXLWithEDM(ExamplesTestsAccelerate):
36+
def test_dreambooth_lora_sdxl_with_edm(self):
37+
with tempfile.TemporaryDirectory() as tmpdir:
38+
test_args = f"""
39+
examples/dreambooth/train_dreambooth_lora_sdxl.py
40+
--pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-xl-pipe
41+
--do_edm_style_training
42+
--instance_data_dir docs/source/en/imgs
43+
--instance_prompt photo
44+
--resolution 64
45+
--train_batch_size 1
46+
--gradient_accumulation_steps 1
47+
--max_train_steps 2
48+
--learning_rate 5.0e-04
49+
--scale_lr
50+
--lr_scheduler constant
51+
--lr_warmup_steps 0
52+
--output_dir {tmpdir}
53+
""".split()
54+
55+
run_command(self._launch_args + test_args)
56+
# save_pretrained smoke test
57+
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
58+
59+
# make sure the state_dict has the correct naming in the parameters.
60+
lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
61+
is_lora = all("lora" in k for k in lora_state_dict.keys())
62+
self.assertTrue(is_lora)
63+
64+
# when not training the text encoder, all the parameters in the state dict should start
65+
# with `"unet"` in their names.
66+
starts_with_unet = all(key.startswith("unet") for key in lora_state_dict.keys())
67+
self.assertTrue(starts_with_unet)
68+
69+
def test_dreambooth_lora_playground(self):
70+
with tempfile.TemporaryDirectory() as tmpdir:
71+
test_args = f"""
72+
examples/dreambooth/train_dreambooth_lora_sdxl.py
73+
--pretrained_model_name_or_path hf-internal-testing/tiny-playground-v2-5-pipe
74+
--instance_data_dir docs/source/en/imgs
75+
--instance_prompt photo
76+
--resolution 64
77+
--train_batch_size 1
78+
--gradient_accumulation_steps 1
79+
--max_train_steps 2
80+
--learning_rate 5.0e-04
81+
--scale_lr
82+
--lr_scheduler constant
83+
--lr_warmup_steps 0
84+
--output_dir {tmpdir}
85+
""".split()
86+
87+
run_command(self._launch_args + test_args)
88+
# save_pretrained smoke test
89+
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
90+
91+
# make sure the state_dict has the correct naming in the parameters.
92+
lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
93+
is_lora = all("lora" in k for k in lora_state_dict.keys())
94+
self.assertTrue(is_lora)
95+
96+
# when not training the text encoder, all the parameters in the state dict should start
97+
# with `"unet"` in their names.
98+
starts_with_unet = all(key.startswith("unet") for key in lora_state_dict.keys())
99+
self.assertTrue(starts_with_unet)

0 commit comments

Comments
 (0)