Commit 4a307fe

Merge branch 'main' into docker-build-workflow
2 parents 6426355 + 56b6845 commit 4a307fe

2 files changed: +191 -20 lines

examples/wuerstchen/text_to_image/requirements.txt (-1)
@@ -2,7 +2,6 @@ accelerate>=0.16.0
 torchvision
 transformers>=4.25.1
 wandb
-huggingface-cli
 bitsandbytes
 deepspeed
 peft>=0.6.0

src/diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py (+191 -19)
@@ -59,6 +59,66 @@
 """


+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
+def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
+    """
+    Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
+    Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
+    """
+    std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
+    std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
+    # rescale the results from guidance (fixes overexposure)
+    noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
+    # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
+    noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
+    return noise_cfg
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+def retrieve_timesteps(
+    scheduler,
+    num_inference_steps: Optional[int] = None,
+    device: Optional[Union[str, torch.device]] = None,
+    timesteps: Optional[List[int]] = None,
+    **kwargs,
+):
+    """
+    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+    custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+    Args:
+        scheduler (`SchedulerMixin`):
+            The scheduler to get timesteps from.
+        num_inference_steps (`int`):
+            The number of diffusion steps used when generating samples with a pre-trained model. If used,
+            `timesteps` must be `None`.
+        device (`str` or `torch.device`, *optional*):
+            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+        timesteps (`List[int]`, *optional*):
+            Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
+            timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps`
+            must be `None`.
+
+    Returns:
+        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+        second element is the number of inference steps.
+    """
+    if timesteps is not None:
+        accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+        if not accepts_timesteps:
+            raise ValueError(
+                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+                f" timestep schedules. Please check whether you are using the correct scheduler."
+            )
+        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+        timesteps = scheduler.timesteps
+        num_inference_steps = len(timesteps)
+    else:
+        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+        timesteps = scheduler.timesteps
+    return timesteps, num_inference_steps
+
+
 @dataclass
 class LDM3DPipelineOutput(BaseOutput):
     """
@@ -125,6 +185,7 @@ class StableDiffusionLDM3DPipeline(
     model_cpu_offload_seq = "text_encoder->unet->vae"
     _optional_components = ["safety_checker", "feature_extractor", "image_encoder"]
     _exclude_from_cpu_offload = ["safety_checker"]
+    _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]

     def __init__(
         self,
@@ -582,6 +643,66 @@ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype
         latents = latents * self.scheduler.init_noise_sigma
         return latents

+    # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
+    def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32):
+        """
+        See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
+
+        Args:
+            timesteps (`torch.Tensor`):
+                generate embedding vectors at these timesteps
+            embedding_dim (`int`, *optional*, defaults to 512):
+                dimension of the embeddings to generate
+            dtype:
+                data type of the generated embeddings
+
+        Returns:
+            `torch.FloatTensor`: Embedding vectors with shape `(len(timesteps), embedding_dim)`
+        """
+        assert len(w.shape) == 1
+        w = w * 1000.0
+
+        half_dim = embedding_dim // 2
+        emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
+        emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
+        emb = w.to(dtype)[:, None] * emb[None, :]
+        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
+        if embedding_dim % 2 == 1:  # zero pad
+            emb = torch.nn.functional.pad(emb, (0, 1))
+        assert emb.shape == (w.shape[0], embedding_dim)
+        return emb
+
+    @property
+    def guidance_scale(self):
+        return self._guidance_scale
+
+    @property
+    def guidance_rescale(self):
+        return self._guidance_rescale
+
+    @property
+    def clip_skip(self):
+        return self._clip_skip
+
+    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+    # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+    # corresponds to doing no classifier free guidance.
+    @property
+    def do_classifier_free_guidance(self):
+        return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None
+
+    @property
+    def cross_attention_kwargs(self):
+        return self._cross_attention_kwargs
+
+    @property
+    def num_timesteps(self):
+        return self._num_timesteps
+
+    @property
+    def interrupt(self):
+        return self._interrupt
+
     @torch.no_grad()
     @replace_example_docstring(EXAMPLE_DOC_STRING)
     def __call__(
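
Note (not part of the diff): a quick shape check for the embedding helper above, assuming `pipe` is an already-loaded StableDiffusionLDM3DPipeline instance.

import torch

w = torch.tensor([6.5, 6.5])  # guidance_scale - 1 for a batch of two samples
emb = pipe.get_guidance_scale_embedding(w, embedding_dim=256)
assert emb.shape == (2, 256)  # sin half and cos half concatenated on dim=1
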
@@ -590,6 +711,7 @@ def __call__(
         height: Optional[int] = None,
         width: Optional[int] = None,
         num_inference_steps: int = 49,
+        timesteps: List[int] = None,
         guidance_scale: float = 5.0,
         negative_prompt: Optional[Union[str, List[str]]] = None,
         num_images_per_prompt: Optional[int] = 1,
@@ -602,10 +724,12 @@ def __call__(
         ip_adapter_image_embeds: Optional[List[torch.FloatTensor]] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
-        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
-        callback_steps: int = 1,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+        guidance_rescale: float = 0.0,
         clip_skip: Optional[int] = None,
+        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+        **kwargs,
     ):
         r"""
         The call function to the pipeline for generation.
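
Note (not part of the diff): a hedged usage sketch of the extended signature; the checkpoint id is illustrative and any checkpoint compatible with this pipeline should work.

import torch
from diffusers import StableDiffusionLDM3DPipeline

pipe = StableDiffusionLDM3DPipeline.from_pretrained("Intel/ldm3d-4c", torch_dtype=torch.float16).to("cuda")
output = pipe(
    "a photo of an astronaut riding a horse on mars",
    num_inference_steps=49,
    guidance_scale=5.0,
    guidance_rescale=0.7,  # new argument: noise rescaling per arXiv:2305.08891
)
rgb, depth = output.rgb, output.depth
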
@@ -656,18 +780,21 @@ def __call__(
             return_dict (`bool`, *optional*, defaults to `True`):
                 Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
                 plain tuple.
-            callback (`Callable`, *optional*):
-                A function that calls every `callback_steps` steps during inference. The function is called with the
-                following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
-            callback_steps (`int`, *optional*, defaults to 1):
-                The frequency at which the `callback` function is called. If not specified, the callback is called at
-                every step.
             cross_attention_kwargs (`dict`, *optional*):
                 A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
                 [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
             clip_skip (`int`, *optional*):
                 Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
                 the output of the pre-final layer will be used for computing the prompt embeddings.
+            callback_on_step_end (`Callable`, *optional*):
+                A function that calls at the end of each denoising steps during the inference. The function is called
+                with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+                callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+                `callback_on_step_end_tensor_inputs`.
+            callback_on_step_end_tensor_inputs (`List`, *optional*):
+                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+                will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+                `._callback_tensor_inputs` attribute of your pipeline class.

         Examples:

         Returns:
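
Note (not part of the diff): a sketch of a step-end callback, reusing the `pipe` instance from the sketch above; only tensors listed in `callback_on_step_end_tensor_inputs` show up in `callback_kwargs`.

def log_latent_norm(pipeline, step, timestep, callback_kwargs):
    latents = callback_kwargs["latents"]
    print(f"step {step} (t={timestep}): latent norm {latents.norm().item():.2f}")
    return callback_kwargs  # returned tensors are written back into the denoising loop

output = pipe(
    "a photo of a plant",
    callback_on_step_end=log_latent_norm,
    callback_on_step_end_tensor_inputs=["latents"],
)
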
@@ -677,6 +804,22 @@ def __call__(
                 second element is a list of `bool`s indicating whether the corresponding generated image contains
                 "not-safe-for-work" (nsfw) content.
         """
+        callback = kwargs.pop("callback", None)
+        callback_steps = kwargs.pop("callback_steps", None)
+
+        if callback is not None:
+            deprecate(
+                "callback",
+                "1.0.0",
+                "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
+            )
+        if callback_steps is not None:
+            deprecate(
+                "callback_steps",
+                "1.0.0",
+                "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
+            )
+
         # 0. Default height and width to unet
         height = height or self.unet.config.sample_size * self.vae_scale_factor
         width = width or self.unet.config.sample_size * self.vae_scale_factor
@@ -692,8 +835,15 @@ def __call__(
             negative_prompt_embeds,
             ip_adapter_image,
             ip_adapter_image_embeds,
+            callback_on_step_end_tensor_inputs,
         )

+        self._guidance_scale = guidance_scale
+        self._guidance_rescale = guidance_rescale
+        self._clip_skip = clip_skip
+        self._cross_attention_kwargs = cross_attention_kwargs
+        self._interrupt = False
+
         # 2. Define call parameters
         if prompt is not None and isinstance(prompt, str):
             batch_size = 1
@@ -703,26 +853,22 @@ def __call__(
             batch_size = prompt_embeds.shape[0]

         device = self._execution_device
-        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
-        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
-        # corresponds to doing no classifier free guidance.
-        do_classifier_free_guidance = guidance_scale > 1.0

         if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
             image_embeds = self.prepare_ip_adapter_image_embeds(
                 ip_adapter_image,
                 ip_adapter_image_embeds,
                 device,
                 batch_size * num_images_per_prompt,
-                do_classifier_free_guidance,
+                self.do_classifier_free_guidance,
             )

         # 3. Encode input prompt
         prompt_embeds, negative_prompt_embeds = self.encode_prompt(
             prompt,
             device,
             num_images_per_prompt,
-            do_classifier_free_guidance,
+            self.do_classifier_free_guidance,
             negative_prompt,
             prompt_embeds=prompt_embeds,
             negative_prompt_embeds=negative_prompt_embeds,
@@ -731,12 +877,11 @@ def __call__(
         # For classifier free guidance, we need to do two forward passes.
         # Here we concatenate the unconditional and text embeddings into a single batch
         # to avoid doing two forward passes
-        if do_classifier_free_guidance:
+        if self.do_classifier_free_guidance:
             prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

         # 4. Prepare timesteps
-        self.scheduler.set_timesteps(num_inference_steps, device=device)
-        timesteps = self.scheduler.timesteps
+        timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)

         # 5. Prepare latent variables
         num_channels_latents = self.unet.config.in_channels
@@ -757,32 +902,59 @@ def __call__(
         # 6.1 Add image embeds for IP-Adapter
         added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None

+        # 6.2 Optionally get Guidance Scale Embedding
+        timestep_cond = None
+        if self.unet.config.time_cond_proj_dim is not None:
+            guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
+            timestep_cond = self.get_guidance_scale_embedding(
+                guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
+            ).to(device=device, dtype=latents.dtype)
+
         # 7. Denoising loop
         num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+        self._num_timesteps = len(timesteps)
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
+                if self.interrupt:
+                    continue
+
                 # expand the latents if we are doing classifier free guidance
-                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+                latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
                 latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                 # predict the noise residual
                 noise_pred = self.unet(
                     latent_model_input,
                     t,
                     encoder_hidden_states=prompt_embeds,
+                    timestep_cond=timestep_cond,
                     cross_attention_kwargs=cross_attention_kwargs,
                     added_cond_kwargs=added_cond_kwargs,
                     return_dict=False,
                 )[0]

                 # perform guidance
-                if do_classifier_free_guidance:
+                if self.do_classifier_free_guidance:
                     noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                     noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

+                if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
+                    # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
+                    noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)
+
                 # compute the previous noisy sample x_t -> x_t-1
                 latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]

+                if callback_on_step_end is not None:
+                    callback_kwargs = {}
+                    for k in callback_on_step_end_tensor_inputs:
+                        callback_kwargs[k] = locals()[k]
+                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+                    latents = callback_outputs.pop("latents", latents)
+                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+                    negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+
                 # call the callback, if provided
                 if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                     progress_bar.update()
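
Note (not part of the diff): the new `interrupt` property gives callbacks a way to stop denoising early; the `if self.interrupt: continue` check above skips the remaining steps once the flag is set. A sketch, reusing the `pipe` instance from the earlier sketch and assuming flipping the private `_interrupt` flag is acceptable for this purpose.

def stop_after(step_limit):
    def callback(pipeline, step, timestep, callback_kwargs):
        if step + 1 >= step_limit:
            pipeline._interrupt = True
        return callback_kwargs
    return callback

output = pipe("a cozy reading nook", num_inference_steps=49, callback_on_step_end=stop_after(10))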
