|
10 | 10 | from diffusers.loaders import LoraLoaderMixin
|
11 | 11 | from diffusers.models import AutoencoderKL, UNet2DConditionModel
|
12 | 12 | from diffusers.models.lora import adjust_lora_scale_text_encoder
|
| 13 | +from diffusers.pipelines.pipeline_utils import StableDiffusionMixin |
13 | 14 | from diffusers.pipelines.stable_diffusion.pipeline_output import StableDiffusionPipelineOutput
|
14 | 15 | from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
15 | 16 | from diffusers.schedulers import KarrasDiffusionSchedulers
|
@@ -193,7 +194,7 @@ def retrieve_timesteps(
|
193 | 194 | return timesteps, num_inference_steps
|
194 | 195 |
|
195 | 196 |
|
196 |
| -class GlueGenStableDiffusionPipeline(DiffusionPipeline, LoraLoaderMixin): |
| 197 | +class GlueGenStableDiffusionPipeline(DiffusionPipeline, StableDiffusionMixin, LoraLoaderMixin): |
197 | 198 | def __init__(
|
198 | 199 | self,
|
199 | 200 | vae: AutoencoderKL,
|
@@ -241,35 +242,6 @@ def load_language_adapter(
|
241 | 242 | )
|
242 | 243 | self.language_adapter.load_state_dict(torch.load(model_path))
|
243 | 244 |
|
244 |
| - def enable_vae_slicing(self): |
245 |
| - r""" |
246 |
| - Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to |
247 |
| - compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. |
248 |
| - """ |
249 |
| - self.vae.enable_slicing() |
250 |
| - |
251 |
| - def disable_vae_slicing(self): |
252 |
| - r""" |
253 |
| - Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to |
254 |
| - computing decoding in one step. |
255 |
| - """ |
256 |
| - self.vae.disable_slicing() |
257 |
| - |
258 |
| - def enable_vae_tiling(self): |
259 |
| - r""" |
260 |
| - Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to |
261 |
| - compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow |
262 |
| - processing larger images. |
263 |
| - """ |
264 |
| - self.vae.enable_tiling() |
265 |
| - |
266 |
| - def disable_vae_tiling(self): |
267 |
| - r""" |
268 |
| - Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to |
269 |
| - computing decoding in one step. |
270 |
| - """ |
271 |
| - self.vae.disable_tiling() |
272 |
| - |
273 | 245 | def _adapt_language(self, prompt_embeds: torch.FloatTensor):
|
274 | 246 | prompt_embeds = prompt_embeds / 3
|
275 | 247 | prompt_embeds = self.language_adapter(prompt_embeds) * (self.tensor_norm / 2)
|
@@ -544,32 +516,6 @@ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype
|
544 | 516 | latents = latents * self.scheduler.init_noise_sigma
|
545 | 517 | return latents
|
546 | 518 |
|
547 |
| - def enable_freeu(self, s1: float, s2: float, b1: float, b2: float): |
548 |
| - r"""Enables the FreeU mechanism as in https://arxiv.org/abs/2309.11497. |
549 |
| -
|
550 |
| - The suffixes after the scaling factors represent the stages where they are being applied. |
551 |
| -
|
552 |
| - Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of the values |
553 |
| - that are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL. |
554 |
| -
|
555 |
| - Args: |
556 |
| - s1 (`float`): |
557 |
| - Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to |
558 |
| - mitigate "oversmoothing effect" in the enhanced denoising process. |
559 |
| - s2 (`float`): |
560 |
| - Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to |
561 |
| - mitigate "oversmoothing effect" in the enhanced denoising process. |
562 |
| - b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features. |
563 |
| - b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features. |
564 |
| - """ |
565 |
| - if not hasattr(self, "unet"): |
566 |
| - raise ValueError("The pipeline must have `unet` for using FreeU.") |
567 |
| - self.unet.enable_freeu(s1=s1, s2=s2, b1=b1, b2=b2) |
568 |
| - |
569 |
| - def disable_freeu(self): |
570 |
| - """Disables the FreeU mechanism if enabled.""" |
571 |
| - self.unet.disable_freeu() |
572 |
| - |
573 | 519 | # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
|
574 | 520 | def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32):
|
575 | 521 | """
|
|
0 commit comments