"""

+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
+def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
+    """
+    Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
+    Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
+    """
+    std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
+    std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
+    # rescale the results from guidance (fixes overexposure)
+    noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
+    # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
+    noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
+    return noise_cfg
+
+
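A minimal sketch of the new helper in use (not from the commit; the tensor shapes are illustrative, and 0.7 is the rescale value the linked paper suggests):

    import torch

    noise_cfg = torch.randn(2, 4, 64, 64)        # CFG-combined prediction: (batch, channels, height, width)
    noise_pred_text = torch.randn(2, 4, 64, 64)  # text-conditioned prediction from the same step
    # guidance_rescale=0.0 is a no-op; larger values pull the std of the
    # combined prediction back toward the std of the text-only prediction
    rescaled = rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.7)
    assert rescaled.shape == noise_cfg.shape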
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+def retrieve_timesteps(
+    scheduler,
+    num_inference_steps: Optional[int] = None,
+    device: Optional[Union[str, torch.device]] = None,
+    timesteps: Optional[List[int]] = None,
+    **kwargs,
+):
+    """
+    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+    custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+    Args:
+        scheduler (`SchedulerMixin`):
+            The scheduler to get timesteps from.
+        num_inference_steps (`int`):
+            The number of diffusion steps used when generating samples with a pre-trained model. If used,
+            `timesteps` must be `None`.
+        device (`str` or `torch.device`, *optional*):
+            The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
+        timesteps (`List[int]`, *optional*):
+            Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
+            timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps`
+            must be `None`.
+
+    Returns:
+        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and
+        the second element is the number of inference steps.
+    """
+    if timesteps is not None:
+        accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+        if not accepts_timesteps:
+            raise ValueError(
+                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+                f" timestep schedules. Please check whether you are using the correct scheduler."
+            )
+        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+        timesteps = scheduler.timesteps
+        num_inference_steps = len(timesteps)
+    else:
+        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+        timesteps = scheduler.timesteps
+    return timesteps, num_inference_steps
+
+
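A usage sketch, assuming `pipe` is an instantiated pipeline; pass either `num_inference_steps` or a custom `timesteps` list, never both:

    # default spacing, derived by the scheduler
    timesteps, num_inference_steps = retrieve_timesteps(pipe.scheduler, num_inference_steps=50, device="cuda")

    # custom spacing, valid only if the scheduler's `set_timesteps` accepts a `timesteps` argument
    timesteps, num_inference_steps = retrieve_timesteps(pipe.scheduler, timesteps=[999, 749, 499, 249, 0], device="cuda")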
@dataclass
class LDM3DPipelineOutput(BaseOutput):
"""
@@ -125,6 +185,7 @@ class StableDiffusionLDM3DPipeline(
    model_cpu_offload_seq = "text_encoder->unet->vae"
    _optional_components = ["safety_checker", "feature_extractor", "image_encoder"]
    _exclude_from_cpu_offload = ["safety_checker"]
+    _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]

    def __init__(
        self,
@@ -582,6 +643,66 @@ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype
        latents = latents * self.scheduler.init_noise_sigma
        return latents

+    # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
+    def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32):
+        """
+        See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
+
+        Args:
+            w (`torch.Tensor`):
+                Guidance scale values at which to generate the embedding vectors.
+            embedding_dim (`int`, *optional*, defaults to 512):
+                Dimension of the embeddings to generate.
+            dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
+                Data type of the generated embeddings.
+
+        Returns:
+            `torch.FloatTensor`: Embedding vectors with shape `(len(w), embedding_dim)`.
+        """
+        assert len(w.shape) == 1
+        w = w * 1000.0
+
+        half_dim = embedding_dim // 2
+        emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
+        emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
+        emb = w.to(dtype)[:, None] * emb[None, :]
+        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
+        if embedding_dim % 2 == 1:  # zero pad
+            emb = torch.nn.functional.pad(emb, (0, 1))
+        assert emb.shape == (w.shape[0], embedding_dim)
+        return emb
+
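A quick shape check for this helper, assuming a pipeline whose UNet was trained with `time_cond_proj_dim=256` (the 256 is an assumption, not something this diff sets):

    w = torch.tensor([7.5 - 1.0])  # guidance_scale - 1, a batch of one
    emb = pipe.get_guidance_scale_embedding(w, embedding_dim=256)
    assert emb.shape == (1, 256)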
+    @property
+    def guidance_scale(self):
+        return self._guidance_scale
+
+    @property
+    def guidance_rescale(self):
+        return self._guidance_rescale
+
+    @property
+    def clip_skip(self):
+        return self._clip_skip
+
+    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+    # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+    # corresponds to doing no classifier free guidance.
+    @property
+    def do_classifier_free_guidance(self):
+        return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None
+
+    @property
+    def cross_attention_kwargs(self):
+        return self._cross_attention_kwargs
+
+    @property
+    def num_timesteps(self):
+        return self._num_timesteps
+
+    @property
+    def interrupt(self):
+        return self._interrupt
+
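Note that `do_classifier_free_guidance` is `False` whenever the UNet has `time_cond_proj_dim` set: such guidance-distilled (LCM-style) UNets consume the guidance scale through `get_guidance_scale_embedding` above instead of running a second, unconditional forward pass.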
    @torch.no_grad()
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(
@@ -590,6 +711,7 @@ def __call__(
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 49,
+        timesteps: List[int] = None,
        guidance_scale: float = 5.0,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
@@ -602,10 +724,12 @@ def __call__(
        ip_adapter_image_embeds: Optional[List[torch.FloatTensor]] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
-        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
-        callback_steps: int = 1,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+        guidance_rescale: float = 0.0,
        clip_skip: Optional[int] = None,
+        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+        **kwargs,
    ):
        r"""
        The call function to the pipeline for generation.
@@ -656,18 +780,21 @@ def __call__(
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
                plain tuple.
-            callback (`Callable`, *optional*):
-                A function that calls every `callback_steps` steps during inference. The function is called with the
-                following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
-            callback_steps (`int`, *optional*, defaults to 1):
-                The frequency at which the `callback` function is called. If not specified, the callback is called at
-                every step.
            cross_attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
                [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
            clip_skip (`int`, *optional*):
                Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
                the output of the pre-final layer will be used for computing the prompt embeddings.
+            callback_on_step_end (`Callable`, *optional*):
+                A function that is called at the end of each denoising step during inference. The function is called
+                with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+                callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+                `callback_on_step_end_tensor_inputs`.
+            callback_on_step_end_tensor_inputs (`List`, *optional*):
+                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+                will be passed as the `callback_kwargs` argument. You will only be able to include variables listed in
+                the `._callback_tensor_inputs` attribute of your pipeline class.

        Examples:

        Returns:
@@ -677,6 +804,22 @@ def __call__(
                second element is a list of `bool`s indicating whether the corresponding generated image contains
                "not-safe-for-work" (nsfw) content.
        """
+        callback = kwargs.pop("callback", None)
+        callback_steps = kwargs.pop("callback_steps", None)
+
+        if callback is not None:
+            deprecate(
+                "callback",
+                "1.0.0",
+                "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
+            )
+        if callback_steps is not None:
+            deprecate(
+                "callback_steps",
+                "1.0.0",
+                "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
+            )
+
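Migrating off the deprecated arguments is mechanical; a sketch with an assumed `pipe` and `prompt`:

    # before (still accepted via **kwargs, but warns; removal slated for 1.0.0):
    #   pipe(prompt, callback=on_step, callback_steps=1)

    # after:
    def on_step(pipeline, step, timestep, callback_kwargs):
        return callback_kwargs  # must return the dict, possibly with edited tensors

    pipe(prompt, callback_on_step_end=on_step, callback_on_step_end_tensor_inputs=["latents"])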
        # 0. Default height and width to unet
        height = height or self.unet.config.sample_size * self.vae_scale_factor
        width = width or self.unet.config.sample_size * self.vae_scale_factor
@@ -692,8 +835,15 @@ def __call__(
            negative_prompt_embeds,
            ip_adapter_image,
            ip_adapter_image_embeds,
+            callback_on_step_end_tensor_inputs,
        )

+        self._guidance_scale = guidance_scale
+        self._guidance_rescale = guidance_rescale
+        self._clip_skip = clip_skip
+        self._cross_attention_kwargs = cross_attention_kwargs
+        self._interrupt = False
+
        # 2. Define call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
@@ -703,26 +853,22 @@ def __call__(
            batch_size = prompt_embeds.shape[0]

        device = self._execution_device
-        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
-        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
-        # corresponds to doing no classifier free guidance.
-        do_classifier_free_guidance = guidance_scale > 1.0

        if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
            image_embeds = self.prepare_ip_adapter_image_embeds(
                ip_adapter_image,
                ip_adapter_image_embeds,
                device,
                batch_size * num_images_per_prompt,
-                do_classifier_free_guidance,
+                self.do_classifier_free_guidance,
            )

        # 3. Encode input prompt
        prompt_embeds, negative_prompt_embeds = self.encode_prompt(
            prompt,
            device,
            num_images_per_prompt,
-            do_classifier_free_guidance,
+            self.do_classifier_free_guidance,
            negative_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
@@ -731,12 +877,11 @@ def __call__(
        # For classifier free guidance, we need to do two forward passes.
        # Here we concatenate the unconditional and text embeddings into a single batch
        # to avoid doing two forward passes
-        if do_classifier_free_guidance:
+        if self.do_classifier_free_guidance:
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

        # 4. Prepare timesteps
-        self.scheduler.set_timesteps(num_inference_steps, device=device)
-        timesteps = self.scheduler.timesteps
+        timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)

        # 5. Prepare latent variables
        num_channels_latents = self.unet.config.in_channels
@@ -757,32 +902,59 @@ def __call__(
        # 6.1 Add image embeds for IP-Adapter
        added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None

+        # 6.2 Optionally get Guidance Scale Embedding
+        timestep_cond = None
+        if self.unet.config.time_cond_proj_dim is not None:
+            guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
+            timestep_cond = self.get_guidance_scale_embedding(
+                guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
+            ).to(device=device, dtype=latents.dtype)
+
        # 7. Denoising loop
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+        self._num_timesteps = len(timesteps)
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
+                if self.interrupt:
+                    continue
+
                # expand the latents if we are doing classifier free guidance
-                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+                latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                # predict the noise residual
                noise_pred = self.unet(
                    latent_model_input,
                    t,
                    encoder_hidden_states=prompt_embeds,
+                    timestep_cond=timestep_cond,
                    cross_attention_kwargs=cross_attention_kwargs,
                    added_cond_kwargs=added_cond_kwargs,
                    return_dict=False,
                )[0]

                # perform guidance
-                if do_classifier_free_guidance:
+                if self.do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

+                if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
+                    # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
+                    noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)
+
                # compute the previous noisy sample x_t -> x_t-1
                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
+                if callback_on_step_end is not None:
+                    callback_kwargs = {}
+                    for k in callback_on_step_end_tensor_inputs:
+                        callback_kwargs[k] = locals()[k]
+                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+                    latents = callback_outputs.pop("latents", latents)
+                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+                    negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+
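For illustration, a callback this block could invoke; it edits `latents` in flight, and the key names must come from `_callback_tensor_inputs`:

    def dampen_final_step(pipeline, step, timestep, callback_kwargs):
        # purely illustrative: scale the latents down slightly on the last step
        if step == pipeline.num_timesteps - 1:
            callback_kwargs["latents"] = callback_kwargs["latents"] * 0.99
        return callback_kwargs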
                # call the callback, if provided
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()