diff --git a/docs/source/api/schedulers.mdx b/docs/source/api/schedulers.mdx index 2b6e58fe128d..b5af14d4bf4a 100644 --- a/docs/source/api/schedulers.mdx +++ b/docs/source/api/schedulers.mdx @@ -44,8 +44,7 @@ To this end, the design of schedulers is such that: The core API for any new scheduler must follow a limited structure. - Schedulers should provide one or more `def step(...)` functions that should be called to update the generated sample iteratively. - Schedulers should provide a `set_timesteps(...)` method that configures the parameters of a schedule function for a specific inference task. -- Schedulers should be framework-agnostic, but provide a simple functionality to convert the scheduler into a specific framework, such as PyTorch -with a `set_format(...)` method. +- Schedulers should be framework-specific. The base class [`SchedulerMixin`] implements low level utilities used by multiple schedulers. diff --git a/examples/community/clip_guided_stable_diffusion.py b/examples/community/clip_guided_stable_diffusion.py index f78175735931..33fba10b2687 100644 --- a/examples/community/clip_guided_stable_diffusion.py +++ b/examples/community/clip_guided_stable_diffusion.py @@ -274,7 +274,7 @@ def __call__( # the model input needs to be scaled to match the continuous ODE formulation in K-LMS latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5) - # # predict the noise residual + # predict the noise residual noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample # perform classifier free guidance diff --git a/examples/textual_inversion/textual_inversion.py b/examples/textual_inversion/textual_inversion.py index de5761646a00..53b4cf2f1d5c 100644 --- a/examples/textual_inversion/textual_inversion.py +++ b/examples/textual_inversion/textual_inversion.py @@ -424,7 +424,10 @@ def main(): # TODO (patil-suraj): load scheduler using args noise_scheduler = DDPMScheduler( - beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, tensor_format="pt" + beta_start=0.00085, + beta_end=0.012, + beta_schedule="scaled_linear", + num_train_timesteps=1000, ) train_dataset = TextualInversionDataset( diff --git a/examples/unconditional_image_generation/train_unconditional.py b/examples/unconditional_image_generation/train_unconditional.py index f6affe8a1400..243e433a5b83 100644 --- a/examples/unconditional_image_generation/train_unconditional.py +++ b/examples/unconditional_image_generation/train_unconditional.py @@ -59,7 +59,7 @@ def main(args): "UpBlock2D", ), ) - noise_scheduler = DDPMScheduler(num_train_timesteps=1000, tensor_format="pt") + noise_scheduler = DDPMScheduler(num_train_timesteps=1000) optimizer = torch.optim.AdamW( model.parameters(), lr=args.learning_rate, diff --git a/src/diffusers/pipelines/ddim/pipeline_ddim.py b/src/diffusers/pipelines/ddim/pipeline_ddim.py index 17b8ec83a9a4..74607fe87a3d 100644 --- a/src/diffusers/pipelines/ddim/pipeline_ddim.py +++ b/src/diffusers/pipelines/ddim/pipeline_ddim.py @@ -35,7 +35,6 @@ class DDIMPipeline(DiffusionPipeline): def __init__(self, unet, scheduler): super().__init__() - scheduler = scheduler.set_format("pt") self.register_modules(unet=unet, scheduler=scheduler) @torch.no_grad() diff --git a/src/diffusers/pipelines/ddpm/pipeline_ddpm.py b/src/diffusers/pipelines/ddpm/pipeline_ddpm.py index 9f62a6a1258a..aae29737aae3 100644 --- a/src/diffusers/pipelines/ddpm/pipeline_ddpm.py +++ b/src/diffusers/pipelines/ddpm/pipeline_ddpm.py @@ -35,7 +35,6 @@ class 
DDPMPipeline(DiffusionPipeline): def __init__(self, unet, scheduler): super().__init__() - scheduler = scheduler.set_format("pt") self.register_modules(unet=unet, scheduler=scheduler) @torch.no_grad() diff --git a/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py b/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py index 968a93e80f09..556e4211892b 100644 --- a/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +++ b/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py @@ -45,7 +45,6 @@ def __init__( scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], ): super().__init__() - scheduler = scheduler.set_format("pt") self.register_modules(vqvae=vqvae, bert=bert, tokenizer=tokenizer, unet=unet, scheduler=scheduler) @torch.no_grad() diff --git a/src/diffusers/pipelines/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py b/src/diffusers/pipelines/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py index b7104bd709c6..ef82abb7e6cb 100644 --- a/src/diffusers/pipelines/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py +++ b/src/diffusers/pipelines/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py @@ -23,7 +23,6 @@ class LDMPipeline(DiffusionPipeline): def __init__(self, vqvae: VQModel, unet: UNet2DModel, scheduler: DDIMScheduler): super().__init__() - scheduler = scheduler.set_format("pt") self.register_modules(vqvae=vqvae, unet=unet, scheduler=scheduler) @torch.no_grad() diff --git a/src/diffusers/pipelines/pndm/pipeline_pndm.py b/src/diffusers/pipelines/pndm/pipeline_pndm.py index 483f86da16db..f360da09ac94 100644 --- a/src/diffusers/pipelines/pndm/pipeline_pndm.py +++ b/src/diffusers/pipelines/pndm/pipeline_pndm.py @@ -39,7 +39,6 @@ class PNDMPipeline(DiffusionPipeline): def __init__(self, unet: UNet2DModel, scheduler: PNDMScheduler): super().__init__() - scheduler = scheduler.set_format("pt") self.register_modules(unet=unet, scheduler=scheduler) @torch.no_grad() diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index 6bd8f76a0487..77f25ef1b9c5 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -57,7 +57,6 @@ def __init__( feature_extractor: CLIPFeatureExtractor, ): super().__init__() - scheduler = scheduler.set_format("pt") if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: warnings.warn( diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py index ca8002129a8b..f2ccee71c024 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py @@ -69,7 +69,6 @@ def __init__( feature_extractor: CLIPFeatureExtractor, ): super().__init__() - scheduler = scheduler.set_format("pt") if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: warnings.warn( diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py index 1c6ee1045ebb..a95f9152279a 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +++ 
b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py @@ -83,7 +83,6 @@ def __init__( feature_extractor: CLIPFeatureExtractor, ): super().__init__() - scheduler = scheduler.set_format("pt") logger.info("`StableDiffusionInpaintPipeline` is experimental and will very likely change in the future.") if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: @@ -320,11 +319,11 @@ def __call__( if isinstance(self.scheduler, LMSDiscreteScheduler): latents = self.scheduler.step(noise_pred, t_index, latents, **extra_step_kwargs).prev_sample # masking - init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.tensor(t_index)) + init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.LongTensor([t_index])) else: latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample # masking - init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, t) + init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.LongTensor([t])) latents = (init_latents_proper * mask) + (latents * (1 - mask)) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_onnx.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_onnx.py index ba09f7274cc6..23bd8fbc367a 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_onnx.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_onnx.py @@ -35,7 +35,6 @@ def __init__( feature_extractor: CLIPFeatureExtractor, ): super().__init__() - scheduler = scheduler.set_format("np") self.register_modules( vae_decoder=vae_decoder, text_encoder=text_encoder, diff --git a/src/diffusers/pipelines/stochastic_karras_ve/pipeline_stochastic_karras_ve.py b/src/diffusers/pipelines/stochastic_karras_ve/pipeline_stochastic_karras_ve.py index 1e7e7a26a216..35d06106869e 100644 --- a/src/diffusers/pipelines/stochastic_karras_ve/pipeline_stochastic_karras_ve.py +++ b/src/diffusers/pipelines/stochastic_karras_ve/pipeline_stochastic_karras_ve.py @@ -29,7 +29,6 @@ class KarrasVePipeline(DiffusionPipeline): def __init__(self, unet: UNet2DModel, scheduler: KarrasVeScheduler): super().__init__() - scheduler = scheduler.set_format("pt") self.register_modules(unet=unet, scheduler=scheduler) @torch.no_grad() diff --git a/src/diffusers/schedulers/README.md b/src/diffusers/schedulers/README.md index edf2299446fe..72ec3f023906 100644 --- a/src/diffusers/schedulers/README.md +++ b/src/diffusers/schedulers/README.md @@ -8,8 +8,7 @@ - Schedulers should provide one or more `def step(...)` functions that should be called iteratively to unroll the diffusion loop during the forward pass. -- Schedulers should be framework-agnostic, but provide a simple functionality to convert the scheduler into a specific framework, such as PyTorch -with a `set_format(...)` method. +- Schedulers should be framework-specific.
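With schedulers now framework-specific, the PyTorch path is the default and `set_format(...)` disappears from both construction and the pipelines. A minimal sketch of the updated usage, assuming the post-diff `diffusers` API; the random tensors are stand-ins for a real UNet's inputs and outputs:

```python
import torch
from diffusers import DDPMScheduler

# No tensor_format argument and no set_format() call anymore: the scheduler
# is PyTorch-native and operates on torch tensors directly.
scheduler = DDPMScheduler(num_train_timesteps=1000)

# Inference: unroll the diffusion loop with step().
scheduler.set_timesteps(num_inference_steps=50)
sample = torch.randn(1, 3, 32, 32)  # stand-in for a noisy image batch
for t in scheduler.timesteps:
    model_output = torch.randn_like(sample)  # stand-in for unet(sample, t).sample
    sample = scheduler.step(model_output, t, sample).prev_sample

# Training: add_noise() now also expects torch tensors for timesteps,
# matching the updated textual_inversion example above.
clean = torch.randn(4, 3, 32, 32)
noise = torch.randn_like(clean)
timesteps = torch.randint(0, 1000, (clean.shape[0],)).long()
noisy = scheduler.add_noise(clean, noise, timesteps)
```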
## Examples diff --git a/src/diffusers/schedulers/scheduling_ddim.py b/src/diffusers/schedulers/scheduling_ddim.py index 0613ffd41d0e..6880700ecef0 100644 --- a/src/diffusers/schedulers/scheduling_ddim.py +++ b/src/diffusers/schedulers/scheduling_ddim.py @@ -46,7 +46,7 @@ class DDIMSchedulerOutput(BaseOutput): pred_original_sample: Optional[torch.FloatTensor] = None -def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999): +def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> torch.Tensor: """ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of (1-beta) over time from t = [0,1]. @@ -72,7 +72,7 @@ def alpha_bar(time_step): t1 = i / num_diffusion_timesteps t2 = (i + 1) / num_diffusion_timesteps betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) - return np.array(betas, dtype=np.float32) + return torch.tensor(betas) class DDIMScheduler(SchedulerMixin, ConfigMixin): @@ -106,7 +106,6 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin): an offset added to the inference steps. You can use a combination of `offset=1` and `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in stable diffusion. - tensor_format (`str`): whether the scheduler expects pytorch or numpy arrays. """ @@ -121,15 +120,16 @@ def __init__( clip_sample: bool = True, set_alpha_to_one: bool = True, steps_offset: int = 0, - tensor_format: str = "pt", ): if trained_betas is not None: - self.betas = np.asarray(trained_betas) + self.betas = torch.from_numpy(trained_betas) if beta_schedule == "linear": - self.betas = np.linspace(beta_start, beta_end, num_train_timesteps, dtype=np.float32) + self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) elif beta_schedule == "scaled_linear": # this schedule is very specific to the latent diffusion model. - self.betas = np.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=np.float32) ** 2 + self.betas = ( + torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 + ) elif beta_schedule == "squaredcos_cap_v2": # Glide cosine schedule self.betas = betas_for_alpha_bar(num_train_timesteps) @@ -137,20 +137,17 @@ def __init__( raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") self.alphas = 1.0 - self.betas - self.alphas_cumprod = np.cumprod(self.alphas, axis=0) + self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) # At every step in ddim, we are looking into the previous alphas_cumprod # For the final step, there is no previous alphas_cumprod because we are already at 0 # `set_alpha_to_one` decides whether we set this parameter simply to one or # whether we use the final alpha of the "non-previous" one. 
- self.final_alpha_cumprod = np.array(1.0) if set_alpha_to_one else self.alphas_cumprod[0] + self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0] # setable values self.num_inference_steps = None - self.timesteps = np.arange(0, num_train_timesteps)[::-1].copy() - - self.tensor_format = tensor_format - self.set_format(tensor_format=tensor_format) + self.timesteps = np.arange(0, num_train_timesteps)[::-1] def _get_variance(self, timestep, prev_timestep): alpha_prod_t = self.alphas_cumprod[timestep] @@ -186,15 +183,14 @@ def set_timesteps(self, num_inference_steps: int, **kwargs): step_ratio = self.config.num_train_timesteps // self.num_inference_steps # creates integer timesteps by multiplying by ratio # casting to int to avoid issues when num_inference_step is power of 3 - self.timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy() + self.timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1] self.timesteps += offset - self.set_format(tensor_format=self.tensor_format) def step( self, - model_output: Union[torch.FloatTensor, np.ndarray], + model_output: torch.FloatTensor, timestep: int, - sample: Union[torch.FloatTensor, np.ndarray], + sample: torch.FloatTensor, eta: float = 0.0, use_clipped_model_output: bool = False, generator=None, @@ -205,9 +201,9 @@ def step( process from the learned model outputs (most often the predicted noise). Args: - model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model. + model_output (`torch.FloatTensor`): direct output from learned diffusion model. timestep (`int`): current discrete timestep in the diffusion chain. - sample (`torch.FloatTensor` or `np.ndarray`): + sample (`torch.FloatTensor`): current instance of sample being created by diffusion process. eta (`float`): weight of noise for added noise in diffusion step. use_clipped_model_output (`bool`): TODO @@ -251,7 +247,7 @@ def step( # 4. Clip "predicted x_0" if self.config.clip_sample: - pred_original_sample = self.clip(pred_original_sample, -1, 1) + pred_original_sample = torch.clamp(pred_original_sample, -1, 1) # 5. 
compute variance: "sigma_t(η)" -> see formula (16) # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1) @@ -273,9 +269,6 @@ def step( noise = torch.randn(model_output.shape, generator=generator).to(device) variance = self._get_variance(timestep, prev_timestep) ** (0.5) * eta * noise - if not torch.is_tensor(model_output): - variance = variance.numpy() - prev_sample = prev_sample + variance if not return_dict: @@ -285,16 +278,20 @@ def step( def add_noise( self, - original_samples: Union[torch.FloatTensor, np.ndarray], - noise: Union[torch.FloatTensor, np.ndarray], - timesteps: Union[torch.IntTensor, np.ndarray], - ) -> Union[torch.FloatTensor, np.ndarray]: - if self.tensor_format == "pt": - timesteps = timesteps.to(self.alphas_cumprod.device) + original_samples: torch.FloatTensor, + noise: torch.FloatTensor, + timesteps: torch.IntTensor, + ) -> torch.FloatTensor: + timesteps = timesteps.to(self.alphas_cumprod.device) sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5 - sqrt_alpha_prod = self.match_shape(sqrt_alpha_prod, original_samples) + sqrt_alpha_prod = sqrt_alpha_prod.flatten() + while len(sqrt_alpha_prod.shape) < len(original_samples.shape): + sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) + sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5 - sqrt_one_minus_alpha_prod = self.match_shape(sqrt_one_minus_alpha_prod, original_samples) + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() + while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape): + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise return noisy_samples diff --git a/src/diffusers/schedulers/scheduling_ddpm.py b/src/diffusers/schedulers/scheduling_ddpm.py index 440b880385d4..0383dea224c7 100644 --- a/src/diffusers/schedulers/scheduling_ddpm.py +++ b/src/diffusers/schedulers/scheduling_ddpm.py @@ -70,7 +70,7 @@ def alpha_bar(time_step): t1 = i / num_diffusion_timesteps t2 = (i + 1) / num_diffusion_timesteps betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) - return np.array(betas, dtype=np.float32) + return torch.tensor(betas, dtype=torch.float32) class DDPMScheduler(SchedulerMixin, ConfigMixin): @@ -99,7 +99,6 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin): `fixed_small_log`, `fixed_large`, `fixed_large_log`, `learned` or `learned_range`. clip_sample (`bool`, default `True`): option to clip predicted sample between -1 and 1 for numerical stability. - tensor_format (`str`): whether the scheduler expects pytorch or numpy arrays. """ @@ -113,15 +112,16 @@ def __init__( trained_betas: Optional[np.ndarray] = None, variance_type: str = "fixed_small", clip_sample: bool = True, - tensor_format: str = "pt", ): if trained_betas is not None: - self.betas = np.asarray(trained_betas) + self.betas = torch.from_numpy(trained_betas) elif beta_schedule == "linear": - self.betas = np.linspace(beta_start, beta_end, num_train_timesteps, dtype=np.float32) + self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) elif beta_schedule == "scaled_linear": # this schedule is very specific to the latent diffusion model. 
- self.betas = np.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=np.float32) ** 2 + self.betas = ( + torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 + ) elif beta_schedule == "squaredcos_cap_v2": # Glide cosine schedule self.betas = betas_for_alpha_bar(num_train_timesteps) @@ -129,15 +129,12 @@ def __init__( raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") self.alphas = 1.0 - self.betas - self.alphas_cumprod = np.cumprod(self.alphas, axis=0) - self.one = np.array(1.0) + self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) + self.one = torch.tensor(1.0) # setable values self.num_inference_steps = None - self.timesteps = np.arange(0, num_train_timesteps)[::-1].copy() - - self.tensor_format = tensor_format - self.set_format(tensor_format=tensor_format) + self.timesteps = np.arange(0, num_train_timesteps)[::-1] self.variance_type = variance_type @@ -153,8 +150,7 @@ def set_timesteps(self, num_inference_steps: int): self.num_inference_steps = num_inference_steps self.timesteps = np.arange( 0, self.config.num_train_timesteps, self.config.num_train_timesteps // self.num_inference_steps - )[::-1].copy() - self.set_format(tensor_format=self.tensor_format) + )[::-1] def _get_variance(self, t, predicted_variance=None, variance_type=None): alpha_prod_t = self.alphas_cumprod[t] @@ -170,15 +166,15 @@ def _get_variance(self, t, predicted_variance=None, variance_type=None): # hacks - were probably added for training stability if variance_type == "fixed_small": - variance = self.clip(variance, min_value=1e-20) + variance = torch.clamp(variance, min=1e-20) # for rl-diffuser https://arxiv.org/abs/2205.09991 elif variance_type == "fixed_small_log": - variance = self.log(self.clip(variance, min_value=1e-20)) + variance = torch.log(torch.clamp(variance, min=1e-20)) elif variance_type == "fixed_large": variance = self.betas[t] elif variance_type == "fixed_large_log": # Glide max_log - variance = self.log(self.betas[t]) + variance = torch.log(self.betas[t]) elif variance_type == "learned": return predicted_variance elif variance_type == "learned_range": @@ -191,9 +187,9 @@ def _get_variance(self, t, predicted_variance=None, variance_type=None): def step( self, - model_output: Union[torch.FloatTensor, np.ndarray], + model_output: torch.FloatTensor, timestep: int, - sample: Union[torch.FloatTensor, np.ndarray], + sample: torch.FloatTensor, predict_epsilon=True, generator=None, return_dict: bool = True, @@ -203,9 +199,9 @@ def step( process from the learned model outputs (most often the predicted noise). Args: - model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model. + model_output (`torch.FloatTensor`): direct output from learned diffusion model. timestep (`int`): current discrete timestep in the diffusion chain. - sample (`torch.FloatTensor` or `np.ndarray`): + sample (`torch.FloatTensor`): current instance of sample being created by diffusion process. predict_epsilon (`bool`): optional flag to use when model predicts the samples directly instead of the noise, epsilon. @@ -240,7 +236,7 @@ def step( # 3. Clip "predicted x_0" if self.config.clip_sample: - pred_original_sample = self.clip(pred_original_sample, -1, 1) + pred_original_sample = torch.clamp(pred_original_sample, -1, 1) # 4. Compute coefficients for pred_original_sample x_0 and current sample x_t # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf @@ -254,7 +250,9 @@ def step( # 6. 
Add noise variance = 0 if t > 0: - noise = self.randn_like(model_output, generator=generator) + noise = torch.randn( + model_output.size(), dtype=model_output.dtype, layout=model_output.layout, generator=generator + ).to(model_output.device) variance = (self._get_variance(t, predicted_variance=predicted_variance) ** 0.5) * noise pred_prev_sample = pred_prev_sample + variance @@ -266,16 +264,21 @@ def step( def add_noise( self, - original_samples: Union[torch.FloatTensor, np.ndarray], - noise: Union[torch.FloatTensor, np.ndarray], - timesteps: Union[torch.IntTensor, np.ndarray], - ) -> Union[torch.FloatTensor, np.ndarray]: - if self.tensor_format == "pt": - timesteps = timesteps.to(self.alphas_cumprod.device) + original_samples: torch.FloatTensor, + noise: torch.FloatTensor, + timesteps: torch.IntTensor, + ) -> torch.FloatTensor: + timesteps = timesteps.to(self.alphas_cumprod.device) + sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5 - sqrt_alpha_prod = self.match_shape(sqrt_alpha_prod, original_samples) + sqrt_alpha_prod = sqrt_alpha_prod.flatten() + while len(sqrt_alpha_prod.shape) < len(original_samples.shape): + sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) + sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5 - sqrt_one_minus_alpha_prod = self.match_shape(sqrt_one_minus_alpha_prod, original_samples) + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() + while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape): + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise return noisy_samples diff --git a/src/diffusers/schedulers/scheduling_karras_ve.py b/src/diffusers/schedulers/scheduling_karras_ve.py index 98dafc72a734..5826858faee4 100644 --- a/src/diffusers/schedulers/scheduling_karras_ve.py +++ b/src/diffusers/schedulers/scheduling_karras_ve.py @@ -74,7 +74,6 @@ class KarrasVeScheduler(SchedulerMixin, ConfigMixin): A reasonable range is [0, 10]. s_max (`float`): the end value of the sigma range where we add noise. A reasonable range is [0.2, 80]. - tensor_format (`str`): whether the scheduler expects pytorch or numpy arrays.
""" @@ -87,15 +86,11 @@ def __init__( s_churn: float = 80, s_min: float = 0.05, s_max: float = 50, - tensor_format: str = "pt", ): # setable values - self.num_inference_steps = None - self.timesteps = None - self.schedule = None # sigma(t_i) - - self.tensor_format = tensor_format - self.set_format(tensor_format=tensor_format) + self.num_inference_steps: int = None + self.timesteps: np.ndarray = None + self.schedule: torch.FloatTensor = None # sigma(t_i) def set_timesteps(self, num_inference_steps: int): """ @@ -108,20 +103,18 @@ def set_timesteps(self, num_inference_steps: int): """ self.num_inference_steps = num_inference_steps self.timesteps = np.arange(0, self.num_inference_steps)[::-1].copy() - self.schedule = [ + schedule = [ ( self.config.sigma_max**2 * (self.config.sigma_min**2 / self.config.sigma_max**2) ** (i / (num_inference_steps - 1)) ) for i in self.timesteps ] - self.schedule = np.array(self.schedule, dtype=np.float32) - - self.set_format(tensor_format=self.tensor_format) + self.schedule = torch.tensor(schedule, dtype=torch.float32) def add_noise_to_input( - self, sample: Union[torch.FloatTensor, np.ndarray], sigma: float, generator: Optional[torch.Generator] = None - ) -> Tuple[Union[torch.FloatTensor, np.ndarray], float]: + self, sample: torch.FloatTensor, sigma: float, generator: Optional[torch.Generator] = None + ) -> Tuple[torch.FloatTensor, float]: """ Explicit Langevin-like "churn" step of adding noise to the sample according to a factor gamma_i ≥ 0 to reach a higher noise level sigma_hat = sigma_i + gamma_i*sigma_i. @@ -142,10 +135,10 @@ def add_noise_to_input( def step( self, - model_output: Union[torch.FloatTensor, np.ndarray], + model_output: torch.FloatTensor, sigma_hat: float, sigma_prev: float, - sample_hat: Union[torch.FloatTensor, np.ndarray], + sample_hat: torch.FloatTensor, return_dict: bool = True, ) -> Union[KarrasVeOutput, Tuple]: """ @@ -153,10 +146,10 @@ def step( process from the learned model outputs (most often the predicted noise). Args: - model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model. + model_output (`torch.FloatTensor`): direct output from learned diffusion model. sigma_hat (`float`): TODO sigma_prev (`float`): TODO - sample_hat (`torch.FloatTensor` or `np.ndarray`): TODO + sample_hat (`torch.FloatTensor`): TODO return_dict (`bool`): option for returning tuple rather than KarrasVeOutput class KarrasVeOutput: updated sample in the diffusion chain and derivative (TODO double check). @@ -180,24 +173,24 @@ def step( def step_correct( self, - model_output: Union[torch.FloatTensor, np.ndarray], + model_output: torch.FloatTensor, sigma_hat: float, sigma_prev: float, - sample_hat: Union[torch.FloatTensor, np.ndarray], - sample_prev: Union[torch.FloatTensor, np.ndarray], - derivative: Union[torch.FloatTensor, np.ndarray], + sample_hat: torch.FloatTensor, + sample_prev: torch.FloatTensor, + derivative: torch.FloatTensor, return_dict: bool = True, ) -> Union[KarrasVeOutput, Tuple]: """ Correct the predicted sample based on the output model_output of the network. TODO complete description Args: - model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model. + model_output (`torch.FloatTensor`): direct output from learned diffusion model. 
sigma_hat (`float`): TODO sigma_prev (`float`): TODO - sample_hat (`torch.FloatTensor` or `np.ndarray`): TODO - sample_prev (`torch.FloatTensor` or `np.ndarray`): TODO - derivative (`torch.FloatTensor` or `np.ndarray`): TODO + sample_hat (`torch.FloatTensor`): TODO + sample_prev (`torch.FloatTensor`): TODO + derivative (`torch.FloatTensor`): TODO return_dict (`bool`): option for returning tuple rather than KarrasVeOutput class Returns: diff --git a/src/diffusers/schedulers/scheduling_lms_discrete.py b/src/diffusers/schedulers/scheduling_lms_discrete.py index e4ce54a3abf0..6167af5ad42b 100644 --- a/src/diffusers/schedulers/scheduling_lms_discrete.py +++ b/src/diffusers/schedulers/scheduling_lms_discrete.py @@ -63,7 +63,6 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin): `linear` or `scaled_linear`. trained_betas (`np.ndarray`, optional): option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc. - tensor_format (`str`): whether the scheduler expects pytorch or numpy arrays. """ @@ -75,31 +74,29 @@ def __init__( beta_end: float = 0.02, beta_schedule: str = "linear", trained_betas: Optional[np.ndarray] = None, - tensor_format: str = "pt", ): if trained_betas is not None: - self.betas = np.asarray(trained_betas) + self.betas = torch.from_numpy(trained_betas) if beta_schedule == "linear": - self.betas = np.linspace(beta_start, beta_end, num_train_timesteps, dtype=np.float32) + self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) elif beta_schedule == "scaled_linear": # this schedule is very specific to the latent diffusion model. - self.betas = np.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=np.float32) ** 2 + self.betas = ( + torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 + ) else: raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") self.alphas = 1.0 - self.betas - self.alphas_cumprod = np.cumprod(self.alphas, axis=0) + self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) self.sigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5 # setable values self.num_inference_steps = None - self.timesteps = np.arange(0, num_train_timesteps)[::-1].copy() + self.timesteps = np.arange(0, num_train_timesteps)[::-1] # to be consistent has to be smaller than sigmas by 1 self.derivatives = [] - self.tensor_format = tensor_format - self.set_format(tensor_format=tensor_format) - def get_lms_coefficient(self, order, t, current_order): """ Compute a linear multistep coefficient. @@ -131,24 +128,24 @@ def set_timesteps(self, num_inference_steps: int): the number of diffusion steps used when generating samples with a pre-trained model. 
""" self.num_inference_steps = num_inference_steps - self.timesteps = np.linspace(self.config.num_train_timesteps - 1, 0, num_inference_steps, dtype=float) + timesteps = np.linspace(self.config.num_train_timesteps - 1, 0, num_inference_steps, dtype=float) - low_idx = np.floor(self.timesteps).astype(int) - high_idx = np.ceil(self.timesteps).astype(int) - frac = np.mod(self.timesteps, 1.0) + low_idx = np.floor(timesteps).astype(int) + high_idx = np.ceil(timesteps).astype(int) + frac = np.mod(timesteps, 1.0) sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) sigmas = (1 - frac) * sigmas[low_idx] + frac * sigmas[high_idx] - self.sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32) + sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32) + self.sigmas = torch.from_numpy(sigmas) + self.timesteps = timesteps.astype(int) self.derivatives = [] - self.set_format(tensor_format=self.tensor_format) - def step( self, - model_output: Union[torch.FloatTensor, np.ndarray], + model_output: torch.FloatTensor, timestep: int, - sample: Union[torch.FloatTensor, np.ndarray], + sample: torch.FloatTensor, order: int = 4, return_dict: bool = True, ) -> Union[LMSDiscreteSchedulerOutput, Tuple]: @@ -157,9 +154,9 @@ def step( process from the learned model outputs (most often the predicted noise). Args: - model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model. + model_output (`torch.FloatTensor`): direct output from learned diffusion model. timestep (`int`): current discrete timestep in the diffusion chain. - sample (`torch.FloatTensor` or `np.ndarray`): + sample (`torch.FloatTensor`): current instance of sample being created by diffusion process. order: coefficient for multi-step inference. return_dict (`bool`): option for returning tuple rather than LMSDiscreteSchedulerOutput class @@ -197,15 +194,18 @@ def step( def add_noise( self, - original_samples: Union[torch.FloatTensor, np.ndarray], - noise: Union[torch.FloatTensor, np.ndarray], - timesteps: Union[torch.IntTensor, np.ndarray], - ) -> Union[torch.FloatTensor, np.ndarray]: - if self.tensor_format == "pt": - timesteps = timesteps.to(self.sigmas.device) - sigmas = self.match_shape(self.sigmas[timesteps], noise) - noisy_samples = original_samples + noise * sigmas - + original_samples: torch.FloatTensor, + noise: torch.FloatTensor, + timesteps: torch.IntTensor, + ) -> torch.FloatTensor: + sigmas = self.sigmas.to(original_samples.device) + timesteps = timesteps.to(original_samples.device) + + sigma = sigmas[timesteps].flatten() + while len(sigma.shape) < len(original_samples.shape): + sigma = sigma.unsqueeze(-1) + + noisy_samples = original_samples + noise * sigma return noisy_samples def __len__(self): diff --git a/src/diffusers/schedulers/scheduling_lms_discrete_flax.py b/src/diffusers/schedulers/scheduling_lms_discrete_flax.py index fcaed372f045..303ef4e47d61 100644 --- a/src/diffusers/schedulers/scheduling_lms_discrete_flax.py +++ b/src/diffusers/schedulers/scheduling_lms_discrete_flax.py @@ -132,7 +132,7 @@ def set_timesteps(self, state: LMSDiscreteSchedulerState, num_inference_steps: i return state.replace( num_inference_steps=num_inference_steps, - timesteps=timesteps, + timesteps=timesteps.astype(int), derivatives=jnp.array([]), sigmas=sigmas, ) diff --git a/src/diffusers/schedulers/scheduling_pndm.py b/src/diffusers/schedulers/scheduling_pndm.py index 09e8a7e240c2..1935a6ef93f2 100644 --- a/src/diffusers/schedulers/scheduling_pndm.py +++ 
b/src/diffusers/schedulers/scheduling_pndm.py @@ -51,7 +51,7 @@ def alpha_bar(time_step): t1 = i / num_diffusion_timesteps t2 = (i + 1) / num_diffusion_timesteps betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) - return np.array(betas, dtype=np.float32) + return torch.tensor(betas, dtype=torch.float32) class PNDMScheduler(SchedulerMixin, ConfigMixin): @@ -86,7 +86,6 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin): an offset added to the inference steps. You can use a combination of `offset=1` and `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in stable diffusion. - tensor_format (`str`): whether the scheduler expects pytorch or numpy arrays """ @@ -101,15 +100,16 @@ def __init__( skip_prk_steps: bool = False, set_alpha_to_one: bool = False, steps_offset: int = 0, - tensor_format: str = "pt", ): if trained_betas is not None: - self.betas = np.asarray(trained_betas) + self.betas = torch.from_numpy(trained_betas) if beta_schedule == "linear": - self.betas = np.linspace(beta_start, beta_end, num_train_timesteps, dtype=np.float32) + self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) elif beta_schedule == "scaled_linear": # this schedule is very specific to the latent diffusion model. - self.betas = np.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=np.float32) ** 2 + self.betas = ( + torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 + ) elif beta_schedule == "squaredcos_cap_v2": # Glide cosine schedule self.betas = betas_for_alpha_bar(num_train_timesteps) @@ -117,9 +117,9 @@ def __init__( raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") self.alphas = 1.0 - self.betas - self.alphas_cumprod = np.cumprod(self.alphas, axis=0) + self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) - self.final_alpha_cumprod = np.array(1.0) if set_alpha_to_one else self.alphas_cumprod[0] + self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0] # For now we only support F-PNDM, i.e. the runge-kutta method # For more information on the algorithm please take a look at the paper: https://arxiv.org/pdf/2202.09778.pdf @@ -139,9 +139,6 @@ def __init__( self.plms_timesteps = None self.timesteps = None - self.tensor_format = tensor_format - self.set_format(tensor_format=tensor_format) - def set_timesteps(self, num_inference_steps: int, **kwargs) -> torch.FloatTensor: """ Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference. @@ -189,13 +186,12 @@ def set_timesteps(self, num_inference_steps: int, **kwargs) -> torch.FloatTensor self.ets = [] self.counter = 0 - self.set_format(tensor_format=self.tensor_format) def step( self, - model_output: Union[torch.FloatTensor, np.ndarray], + model_output: torch.FloatTensor, timestep: int, - sample: Union[torch.FloatTensor, np.ndarray], + sample: torch.FloatTensor, return_dict: bool = True, ) -> Union[SchedulerOutput, Tuple]: """ @@ -205,9 +201,9 @@ def step( This function calls `step_prk()` or `step_plms()` depending on the internal variable `counter`. Args: - model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model. + model_output (`torch.FloatTensor`): direct output from learned diffusion model. timestep (`int`): current discrete timestep in the diffusion chain. 
- sample (`torch.FloatTensor` or `np.ndarray`): + sample (`torch.FloatTensor`): current instance of sample being created by diffusion process. return_dict (`bool`): option for returning tuple rather than SchedulerOutput class @@ -224,9 +220,9 @@ def step( def step_prk( self, - model_output: Union[torch.FloatTensor, np.ndarray], + model_output: torch.FloatTensor, timestep: int, - sample: Union[torch.FloatTensor, np.ndarray], + sample: torch.FloatTensor, return_dict: bool = True, ) -> Union[SchedulerOutput, Tuple]: """ @@ -234,9 +230,9 @@ def step_prk( solution to the differential equation. Args: - model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model. + model_output (`torch.FloatTensor`): direct output from learned diffusion model. timestep (`int`): current discrete timestep in the diffusion chain. - sample (`torch.FloatTensor` or `np.ndarray`): + sample (`torch.FloatTensor`): current instance of sample being created by diffusion process. return_dict (`bool`): option for returning tuple rather than SchedulerOutput class @@ -279,9 +275,9 @@ def step_prk( def step_plms( self, - model_output: Union[torch.FloatTensor, np.ndarray], + model_output: torch.FloatTensor, timestep: int, - sample: Union[torch.FloatTensor, np.ndarray], + sample: torch.FloatTensor, return_dict: bool = True, ) -> Union[SchedulerOutput, Tuple]: """ @@ -289,9 +285,9 @@ def step_plms( times to approximate the solution. Args: - model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model. + model_output (`torch.FloatTensor`): direct output from learned diffusion model. timestep (`int`): current discrete timestep in the diffusion chain. - sample (`torch.FloatTensor` or `np.ndarray`): + sample (`torch.FloatTensor`): current instance of sample being created by diffusion process. 
return_dict (`bool`): option for returning tuple rather than SchedulerOutput class @@ -381,16 +377,27 @@ def _get_prev_sample(self, sample, timestep, prev_timestep, model_output): def add_noise( self, - original_samples: Union[torch.FloatTensor, np.ndarray], - noise: Union[torch.FloatTensor, np.ndarray], - timesteps: Union[torch.IntTensor, np.ndarray], + original_samples: torch.FloatTensor, + noise: torch.FloatTensor, + timesteps: torch.IntTensor, ) -> torch.Tensor: - if self.tensor_format == "pt": - timesteps = timesteps.to(self.alphas_cumprod.device) + if self.alphas_cumprod.device != original_samples.device: + self.alphas_cumprod = self.alphas_cumprod.to(original_samples.device) + + if timesteps.device != original_samples.device: + timesteps = timesteps.to(original_samples.device) + + timesteps = timesteps.to(self.alphas_cumprod.device) + sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5 - sqrt_alpha_prod = self.match_shape(sqrt_alpha_prod, original_samples) + sqrt_alpha_prod = sqrt_alpha_prod.flatten() + while len(sqrt_alpha_prod.shape) < len(original_samples.shape): + sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) + sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5 - sqrt_one_minus_alpha_prod = self.match_shape(sqrt_one_minus_alpha_prod, original_samples) + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() + while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape): + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise return noisy_samples diff --git a/src/diffusers/schedulers/scheduling_sde_ve.py b/src/diffusers/schedulers/scheduling_sde_ve.py index 4af8f4fdad7d..7b06ae16c5e9 100644 --- a/src/diffusers/schedulers/scheduling_sde_ve.py +++ b/src/diffusers/schedulers/scheduling_sde_ve.py @@ -14,11 +14,11 @@ # DISCLAIMER: This file is strongly influenced by https://github.com/yang-song/score_sde_pytorch +import math import warnings from dataclasses import dataclass from typing import Optional, Tuple, Union -import numpy as np import torch from ..configuration_utils import ConfigMixin, register_to_config @@ -65,7 +65,6 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin): sampling_eps (`float`): the end value of sampling, where timesteps decrease progressively from 1 to epsilon. correct_steps (`int`): number of correction steps performed on a produced sample. - tensor_format (`str`): "np" or "pt" for the expected format of samples passed to the Scheduler. """ @register_to_config @@ -77,16 +76,12 @@ def __init__( sigma_max: float = 1348.0, sampling_eps: float = 1e-5, correct_steps: int = 1, - tensor_format: str = "pt", ): # setable values self.timesteps = None self.set_sigmas(num_train_timesteps, sigma_min, sigma_max, sampling_eps) - self.tensor_format = tensor_format - self.set_format(tensor_format=tensor_format) - def set_timesteps(self, num_inference_steps: int, sampling_eps: float = None): """ Sets the continuous timesteps used for the diffusion chain. Supporting function to be run before inference. 
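Throughout the hunks above, the removed `SchedulerMixin.match_shape` helper is inlined as a flatten-then-unsqueeze loop. A standalone sketch of that broadcasting pattern; `broadcast_to_shape` is a hypothetical name used only for illustration:

```python
import torch

def broadcast_to_shape(values: torch.Tensor, sample: torch.Tensor) -> torch.Tensor:
    # Same effect as the removed match_shape(): reshape a 1-D tensor of
    # per-batch coefficients to (batch, 1, 1, ...) so it broadcasts against
    # a (batch, C, H, W)-shaped sample.
    values = values.flatten()
    while len(values.shape) < len(sample.shape):
        values = values.unsqueeze(-1)
    return values.to(sample.device)

alphas_cumprod = torch.linspace(0.9999, 0.01, 1000)
sample = torch.randn(4, 3, 8, 8)
timesteps = torch.tensor([0, 10, 100, 999])
sqrt_alpha_prod = broadcast_to_shape(alphas_cumprod[timesteps] ** 0.5, sample)
assert sqrt_alpha_prod.shape == (4, 1, 1, 1)
```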
@@ -98,13 +93,8 @@ def set_timesteps(self, num_inference_steps: int, sampling_eps: float = None): """ sampling_eps = sampling_eps if sampling_eps is not None else self.config.sampling_eps - tensor_format = getattr(self, "tensor_format", "pt") - if tensor_format == "np": - self.timesteps = np.linspace(1, sampling_eps, num_inference_steps) - elif tensor_format == "pt": - self.timesteps = torch.linspace(1, sampling_eps, num_inference_steps) - else: - raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.") + + self.timesteps = torch.linspace(1, sampling_eps, num_inference_steps) def set_sigmas( self, num_inference_steps: int, sigma_min: float = None, sigma_max: float = None, sampling_eps: float = None @@ -129,28 +119,16 @@ def set_sigmas( if self.timesteps is None: self.set_timesteps(num_inference_steps, sampling_eps) - tensor_format = getattr(self, "tensor_format", "pt") - if tensor_format == "np": - self.discrete_sigmas = np.exp(np.linspace(np.log(sigma_min), np.log(sigma_max), num_inference_steps)) - self.sigmas = np.array([sigma_min * (sigma_max / sigma_min) ** t for t in self.timesteps]) - elif tensor_format == "pt": - self.discrete_sigmas = torch.exp(torch.linspace(np.log(sigma_min), np.log(sigma_max), num_inference_steps)) - self.sigmas = torch.tensor([sigma_min * (sigma_max / sigma_min) ** t for t in self.timesteps]) - else: - raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.") + self.sigmas = sigma_min * (sigma_max / sigma_min) ** (self.timesteps / sampling_eps) + self.discrete_sigmas = torch.exp(torch.linspace(math.log(sigma_min), math.log(sigma_max), num_inference_steps)) + self.sigmas = torch.tensor([sigma_min * (sigma_max / sigma_min) ** t for t in self.timesteps]) def get_adjacent_sigma(self, timesteps, t): - tensor_format = getattr(self, "tensor_format", "pt") - if tensor_format == "np": - return np.where(timesteps == 0, np.zeros_like(t), self.discrete_sigmas[timesteps - 1]) - elif tensor_format == "pt": - return torch.where( - timesteps == 0, - torch.zeros_like(t.to(timesteps.device)), - self.discrete_sigmas[timesteps - 1].to(timesteps.device), - ) - - raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.") + return torch.where( + timesteps == 0, + torch.zeros_like(t.to(timesteps.device)), + self.discrete_sigmas[timesteps - 1].to(timesteps.device), + ) def set_seed(self, seed): warnings.warn( @@ -158,19 +136,13 @@ def set_seed(self, seed): " generator instead.", DeprecationWarning, ) - tensor_format = getattr(self, "tensor_format", "pt") - if tensor_format == "np": - np.random.seed(seed) - elif tensor_format == "pt": - torch.manual_seed(seed) - else: - raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.") + torch.manual_seed(seed) def step_pred( self, - model_output: Union[torch.FloatTensor, np.ndarray], + model_output: torch.FloatTensor, timestep: int, - sample: Union[torch.FloatTensor, np.ndarray], + sample: torch.FloatTensor, generator: Optional[torch.Generator] = None, return_dict: bool = True, **kwargs, @@ -180,9 +152,9 @@ def step_pred( process from the learned model outputs (most often the predicted noise). Args: - model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model. + model_output (`torch.FloatTensor`): direct output from learned diffusion model. timestep (`int`): current discrete timestep in the diffusion chain. 
- sample (`torch.FloatTensor` or `np.ndarray`): + sample (`torch.FloatTensor`): current instance of sample being created by diffusion process. generator: random number generator. return_dict (`bool`): option for returning tuple rather than SchedulerOutput class @@ -210,18 +182,21 @@ def step_pred( sigma = self.discrete_sigmas[timesteps].to(sample.device) adjacent_sigma = self.get_adjacent_sigma(timesteps, timestep).to(sample.device) - drift = self.zeros_like(sample) + drift = torch.zeros_like(sample) diffusion = (sigma**2 - adjacent_sigma**2) ** 0.5 # equation 6 in the paper: the model_output modeled by the network is grad_x log pt(x) # also equation 47 shows the analog from SDE models to ancestral sampling methods - drift = drift - diffusion[:, None, None, None] ** 2 * model_output + diffusion = diffusion.flatten() + while len(diffusion.shape) < len(sample.shape): + diffusion = diffusion.unsqueeze(-1) + drift = drift - diffusion**2 * model_output # equation 6: sample noise for the diffusion term of - noise = self.randn_like(sample, generator=generator) + noise = torch.randn(sample.shape, layout=sample.layout, generator=generator).to(sample.device) prev_sample_mean = sample - drift # subtract because `dt` is a small negative timestep # TODO is the variable diffusion the correct scaling term for the noise? - prev_sample = prev_sample_mean + diffusion[:, None, None, None] * noise # add impact of diffusion field g + prev_sample = prev_sample_mean + diffusion * noise # add impact of diffusion field g if not return_dict: return (prev_sample, prev_sample_mean) @@ -230,8 +205,8 @@ def step_pred( def step_correct( self, - model_output: Union[torch.FloatTensor, np.ndarray], - sample: Union[torch.FloatTensor, np.ndarray], + model_output: torch.FloatTensor, + sample: torch.FloatTensor, generator: Optional[torch.Generator] = None, return_dict: bool = True, **kwargs, @@ -241,8 +216,8 @@ def step_correct( after making the prediction for the previous timestep. Args: - model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model. - sample (`torch.FloatTensor` or `np.ndarray`): + model_output (`torch.FloatTensor`): direct output from learned diffusion model. + sample (`torch.FloatTensor`): current instance of sample being created by diffusion process. generator: random number generator. return_dict (`bool`): option for returning tuple rather than SchedulerOutput class @@ -262,18 +237,21 @@ def step_correct( # For small batch sizes, the paper "suggest replacing norm(z) with sqrt(d), where d is the dim. 
of z" # sample noise for correction - noise = self.randn_like(sample, generator=generator) + noise = torch.randn(sample.shape, layout=sample.layout, generator=generator).to(sample.device) # compute step size from the model_output, the noise, and the snr - grad_norm = self.norm(model_output) - noise_norm = self.norm(noise) + grad_norm = torch.norm(model_output.reshape(model_output.shape[0], -1), dim=-1).mean() + noise_norm = torch.norm(noise.reshape(noise.shape[0], -1), dim=-1).mean() step_size = (self.config.snr * noise_norm / grad_norm) ** 2 * 2 step_size = step_size * torch.ones(sample.shape[0]).to(sample.device) # self.repeat_scalar(step_size, sample.shape[0]) # compute corrected sample: model_output term and noise term - prev_sample_mean = sample + step_size[:, None, None, None] * model_output - prev_sample = prev_sample_mean + ((step_size * 2) ** 0.5)[:, None, None, None] * noise + step_size = step_size.flatten() + while len(step_size.shape) < len(sample.shape): + step_size = step_size.unsqueeze(-1) + prev_sample_mean = sample + step_size * model_output + prev_sample = prev_sample_mean + ((step_size * 2) ** 0.5) * noise if not return_dict: return (prev_sample,) diff --git a/src/diffusers/schedulers/scheduling_sde_vp.py b/src/diffusers/schedulers/scheduling_sde_vp.py index f19a5ad76f81..2f9821579c52 100644 --- a/src/diffusers/schedulers/scheduling_sde_vp.py +++ b/src/diffusers/schedulers/scheduling_sde_vp.py @@ -16,7 +16,8 @@ # TODO(Patrick, Anton, Suraj) - make scheduler framework independent and clean-up a bit -import numpy as np +import math + import torch from ..configuration_utils import ConfigMixin, register_to_config @@ -39,7 +40,7 @@ class ScoreSdeVpScheduler(SchedulerMixin, ConfigMixin): """ @register_to_config - def __init__(self, num_train_timesteps=2000, beta_min=0.1, beta_max=20, sampling_eps=1e-3, tensor_format="np"): + def __init__(self, num_train_timesteps=2000, beta_min=0.1, beta_max=20, sampling_eps=1e-3): self.sigmas = None self.discrete_sigmas = None self.timesteps = None @@ -47,7 +48,7 @@ def __init__(self, num_train_timesteps=2000, beta_min=0.1, beta_max=20, sampling def set_timesteps(self, num_inference_steps): self.timesteps = torch.linspace(1, self.config.sampling_eps, num_inference_steps) - def step_pred(self, score, x, t): + def step_pred(self, score, x, t, generator=None): if self.timesteps is None: raise ValueError( "`self.timesteps` is not set, you need to run 'set_timesteps' after creating the scheduler" @@ -59,20 +60,27 @@ def step_pred(self, score, x, t): -0.25 * t**2 * (self.config.beta_max - self.config.beta_min) - 0.5 * t * self.config.beta_min ) std = torch.sqrt(1.0 - torch.exp(2.0 * log_mean_coeff)) - score = -score / std[:, None, None, None] + std = std.flatten() + while len(std.shape) < len(score.shape): + std = std.unsqueeze(-1) + score = -score / std # compute dt = -1.0 / len(self.timesteps) beta_t = self.config.beta_min + t * (self.config.beta_max - self.config.beta_min) - drift = -0.5 * beta_t[:, None, None, None] * x + beta_t = beta_t.flatten() + while len(beta_t.shape) < len(x.shape): + beta_t = beta_t.unsqueeze(-1) + drift = -0.5 * beta_t * x + diffusion = torch.sqrt(beta_t) - drift = drift - diffusion[:, None, None, None] ** 2 * score + drift = drift - diffusion**2 * score x_mean = x + drift * dt # add noise - noise = torch.randn_like(x) - x = x_mean + diffusion[:, None, None, None] * np.sqrt(-dt) * noise + noise = torch.randn(x.shape, layout=x.layout, generator=generator).to(x.device) + x = x_mean + diffusion * math.sqrt(-dt) * noise 
         return x, x_mean
diff --git a/src/diffusers/schedulers/scheduling_utils.py b/src/diffusers/schedulers/scheduling_utils.py
index f2bcd73acf32..29bf982f6adf 100644
--- a/src/diffusers/schedulers/scheduling_utils.py
+++ b/src/diffusers/schedulers/scheduling_utils.py
@@ -12,9 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from dataclasses import dataclass
-from typing import Union
 
-import numpy as np
 import torch
 
 from ..utils import BaseOutput
@@ -43,83 +41,3 @@ class SchedulerMixin:
     """
 
     config_name = SCHEDULER_CONFIG_NAME
-    ignore_for_config = ["tensor_format"]
-
-    def set_format(self, tensor_format="pt"):
-        self.tensor_format = tensor_format
-        if tensor_format == "pt":
-            for key, value in vars(self).items():
-                if isinstance(value, np.ndarray):
-                    setattr(self, key, torch.from_numpy(value))
-
-        return self
-
-    def clip(self, tensor, min_value=None, max_value=None):
-        tensor_format = getattr(self, "tensor_format", "pt")
-
-        if tensor_format == "np":
-            return np.clip(tensor, min_value, max_value)
-        elif tensor_format == "pt":
-            return torch.clamp(tensor, min_value, max_value)
-
-        raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.")
-
-    def log(self, tensor):
-        tensor_format = getattr(self, "tensor_format", "pt")
-
-        if tensor_format == "np":
-            return np.log(tensor)
-        elif tensor_format == "pt":
-            return torch.log(tensor)
-
-        raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.")
-
-    def match_shape(self, values: Union[np.ndarray, torch.Tensor], broadcast_array: Union[np.ndarray, torch.Tensor]):
-        """
-        Turns a 1-D array into an array or tensor with len(broadcast_array.shape) dims.
-
-        Args:
-            values: an array or tensor of values to extract.
-            broadcast_array: an array with a larger shape of K dimensions with the batch
-                dimension equal to the length of timesteps.
-        Returns:
-            a tensor of shape [batch_size, 1, ...] where the shape has K dims.
-        """
-
-        tensor_format = getattr(self, "tensor_format", "pt")
-        values = values.flatten()
-
-        while len(values.shape) < len(broadcast_array.shape):
-            values = values[..., None]
-        if tensor_format == "pt":
-            values = values.to(broadcast_array.device)
-
-        return values
-
-    def norm(self, tensor):
-        tensor_format = getattr(self, "tensor_format", "pt")
-        if tensor_format == "np":
-            return np.linalg.norm(tensor)
-        elif tensor_format == "pt":
-            return torch.norm(tensor.reshape(tensor.shape[0], -1), dim=-1).mean()
-
-        raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.")
-
-    def randn_like(self, tensor, generator=None):
-        tensor_format = getattr(self, "tensor_format", "pt")
-        if tensor_format == "np":
-            return np.random.randn(*np.shape(tensor))
-        elif tensor_format == "pt":
-            # return torch.randn_like(tensor)
-            return torch.randn(tensor.shape, layout=tensor.layout, generator=generator).to(tensor.device)
-
-        raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.")
-
-    def zeros_like(self, tensor):
-        tensor_format = getattr(self, "tensor_format", "pt")
-        if tensor_format == "np":
-            return np.zeros_like(tensor)
-        elif tensor_format == "pt":
-            return torch.zeros_like(tensor)
-
-        raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.")
diff --git a/tests/test_pipelines.py b/tests/test_pipelines.py
index dddf42bd03f2..61d5ac3a4e28 100644
--- a/tests/test_pipelines.py
+++ b/tests/test_pipelines.py
@@ -191,7 +191,7 @@ def to(self, device):
 
     def test_ddim(self):
         unet = self.dummy_uncond_unet
-        scheduler = DDIMScheduler(tensor_format="pt")
+        scheduler = DDIMScheduler()
 
         ddpm = DDIMPipeline(unet=unet, scheduler=scheduler)
         ddpm.to(torch_device)
@@ -220,7 +220,7 @@ def test_ddim(self):
 
     def test_pndm_cifar10(self):
         unet = self.dummy_uncond_unet
-        scheduler = PNDMScheduler(tensor_format="pt")
+        scheduler = PNDMScheduler()
 
         pndm = PNDMPipeline(unet=unet, scheduler=scheduler)
         pndm.to(torch_device)
@@ -242,7 +242,7 @@ def test_pndm_cifar10(self):
 
     def test_ldm_text2img(self):
         unet = self.dummy_cond_unet
-        scheduler = DDIMScheduler(tensor_format="pt")
+        scheduler = DDIMScheduler()
         vae = self.dummy_vae
         bert = self.dummy_text_encoder
         tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
@@ -339,7 +339,7 @@ def test_stable_diffusion_ddim(self):
     def test_stable_diffusion_pndm(self):
         device = "cpu"  # ensure determinism for the device-dependent torch.Generator
         unet = self.dummy_cond_unet
-        scheduler = PNDMScheduler(tensor_format="pt", skip_prk_steps=True)
+        scheduler = PNDMScheduler(skip_prk_steps=True)
         vae = self.dummy_vae
         bert = self.dummy_text_encoder
         tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
@@ -460,7 +460,7 @@ def test_stable_diffusion_attention_chunk(self):
 
     def test_score_sde_ve_pipeline(self):
         unet = self.dummy_uncond_unet
-        scheduler = ScoreSdeVeScheduler(tensor_format="pt")
+        scheduler = ScoreSdeVeScheduler()
 
         sde_ve = ScoreSdeVePipeline(unet=unet, scheduler=scheduler)
         sde_ve.to(torch_device)
@@ -484,7 +484,7 @@ def test_score_sde_ve_pipeline(self):
 
     def test_ldm_uncond(self):
         unet = self.dummy_uncond_unet
-        scheduler = DDIMScheduler(tensor_format="pt")
+        scheduler = DDIMScheduler()
         vae = self.dummy_vq_model
 
         ldm = LDMPipeline(unet=unet, vqvae=vae, scheduler=scheduler)
@@ -512,7 +512,7 @@ def test_ldm_uncond(self):
 
     def test_karras_ve_pipeline(self):
         unet = self.dummy_uncond_unet
-        scheduler = KarrasVeScheduler(tensor_format="pt")
+        scheduler = KarrasVeScheduler()
 
         pipe = KarrasVePipeline(unet=unet, scheduler=scheduler)
         pipe.to(torch_device)
@@ -535,7 +535,7 @@ def test_karras_ve_pipeline(self):
     def test_stable_diffusion_img2img(self):
         device = "cpu"  # ensure determinism for the device-dependent torch.Generator
         unet = self.dummy_cond_unet
-        scheduler = PNDMScheduler(tensor_format="pt", skip_prk_steps=True)
+        scheduler = PNDMScheduler(skip_prk_steps=True)
         vae = self.dummy_vae
         bert = self.dummy_text_encoder
         tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
@@ -646,7 +646,7 @@ def test_stable_diffusion_img2img_k_lms(self):
     def test_stable_diffusion_inpaint(self):
         device = "cpu"  # ensure determinism for the device-dependent torch.Generator
         unet = self.dummy_cond_unet
-        scheduler = PNDMScheduler(tensor_format="pt", skip_prk_steps=True)
+        scheduler = PNDMScheduler(skip_prk_steps=True)
         vae = self.dummy_vae
         bert = self.dummy_text_encoder
         tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
@@ -842,7 +842,6 @@ def test_ddpm_cifar10(self):
 
         unet = UNet2DModel.from_pretrained(model_id)
         scheduler = DDPMScheduler.from_config(model_id)
-        scheduler = scheduler.set_format("pt")
 
         ddpm = DDPMPipeline(unet=unet, scheduler=scheduler)
         ddpm.to(torch_device)
@@ -882,7 +881,7 @@ def test_ddim_cifar10(self):
         model_id = "google/ddpm-cifar10-32"
 
         unet = UNet2DModel.from_pretrained(model_id)
-        scheduler = DDIMScheduler(tensor_format="pt")
+        scheduler = DDIMScheduler()
 
         ddim = DDIMPipeline(unet=unet, scheduler=scheduler)
         ddim.to(torch_device)
@@ -902,7 +901,7 @@ def test_pndm_cifar10(self):
         model_id = "google/ddpm-cifar10-32"
 
         unet = UNet2DModel.from_pretrained(model_id)
-        scheduler = PNDMScheduler(tensor_format="pt")
+        scheduler = PNDMScheduler()
 
         pndm = PNDMPipeline(unet=unet, scheduler=scheduler)
         pndm.to(torch_device)
@@ -1043,8 +1042,8 @@ def test_ddpm_ddim_equality(self):
         model_id = "google/ddpm-cifar10-32"
 
         unet = UNet2DModel.from_pretrained(model_id)
-        ddpm_scheduler = DDPMScheduler(tensor_format="pt")
-        ddim_scheduler = DDIMScheduler(tensor_format="pt")
+        ddpm_scheduler = DDPMScheduler()
+        ddim_scheduler = DDIMScheduler()
 
         ddpm = DDPMPipeline(unet=unet, scheduler=ddpm_scheduler)
         ddpm.to(torch_device)
@@ -1067,8 +1066,8 @@ def test_ddpm_ddim_equality_batched(self):
         model_id = "google/ddpm-cifar10-32"
 
         unet = UNet2DModel.from_pretrained(model_id)
-        ddpm_scheduler = DDPMScheduler(tensor_format="pt")
-        ddim_scheduler = DDIMScheduler(tensor_format="pt")
+        ddpm_scheduler = DDPMScheduler()
+        ddim_scheduler = DDIMScheduler()
 
         ddpm = DDPMPipeline(unet=unet, scheduler=ddpm_scheduler)
         ddpm.to(torch_device)
@@ -1093,7 +1092,7 @@ def test_ddpm_ddim_equality_batched(self):
     def test_karras_ve_pipeline(self):
         model_id = "google/ncsnpp-celebahq-256"
         model = UNet2DModel.from_pretrained(model_id)
-        scheduler = KarrasVeScheduler(tensor_format="pt")
+        scheduler = KarrasVeScheduler()
 
         pipe = KarrasVePipeline(unet=model, scheduler=scheduler)
         pipe.to(torch_device)
diff --git a/tests/test_scheduler.py b/tests/test_scheduler.py
index 7377797bebfa..cf3e607ea9d2 100755
--- a/tests/test_scheduler.py
+++ b/tests/test_scheduler.py
@@ -173,34 +173,6 @@ def test_step_shape(self):
         self.assertEqual(output_0.shape, sample.shape)
         self.assertEqual(output_0.shape, output_1.shape)
 
-    def test_pytorch_equal_numpy(self):
-        kwargs = dict(self.forward_default_kwargs)
-
-        num_inference_steps = kwargs.pop("num_inference_steps", None)
-
-        for scheduler_class in self.scheduler_classes:
-            sample_pt = self.dummy_sample
-            residual_pt = 0.1 * sample_pt
-
-            sample = sample_pt.numpy()
-            residual = 0.1 * sample
-
-            scheduler_config = self.get_scheduler_config()
-            scheduler = scheduler_class(tensor_format="np", **scheduler_config)
-
-            scheduler_pt = scheduler_class(tensor_format="pt", **scheduler_config)
-
-            if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
-                scheduler.set_timesteps(num_inference_steps)
-                scheduler_pt.set_timesteps(num_inference_steps)
-            elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
-                kwargs["num_inference_steps"] = num_inference_steps
-
-            output = scheduler.step(residual, 1, sample, **kwargs).prev_sample
-            output_pt = scheduler_pt.step(residual_pt, 1, sample_pt, **kwargs).prev_sample
-
-            assert np.sum(np.abs(output - output_pt.numpy())) < 1e-4, "Scheduler outputs are not identical"
-
     def test_scheduler_outputs_equivalence(self):
         def set_nan_tensor_to_zero(t):
             t[t != t] = 0
@@ -266,7 +238,6 @@ def get_scheduler_config(self, **kwargs):
             "beta_schedule": "linear",
             "variance_type": "fixed_small",
             "clip_sample": True,
-            "tensor_format": "pt",
         }
 
         config.update(**kwargs)
@@ -305,10 +276,6 @@ def test_variance(self):
         assert torch.sum(torch.abs(scheduler._get_variance(487) - 0.00979)) < 1e-5
         assert torch.sum(torch.abs(scheduler._get_variance(999) - 0.02)) < 1e-5
 
-    # TODO Make DDPM Numpy compatible
-    def test_pytorch_equal_numpy(self):
-        pass
-
     def test_full_loop_no_noise(self):
         scheduler_class = self.scheduler_classes[0]
         scheduler_config = self.get_scheduler_config()
@@ -387,7 +354,7 @@ def test_steps_offset(self):
             scheduler_config = self.get_scheduler_config(steps_offset=1)
             scheduler = scheduler_class(**scheduler_config)
             scheduler.set_timesteps(5)
-            assert torch.equal(scheduler.timesteps, torch.tensor([801, 601, 401, 201, 1]))
+            assert np.equal(scheduler.timesteps, np.array([801, 601, 401, 201, 1])).all()
 
     def test_betas(self):
         for beta_start, beta_end in zip([0.0001, 0.001, 0.01, 0.1], [0.002, 0.02, 0.2, 2]):
@@ -556,72 +523,6 @@ def full_loop(self, **config):
 
         return sample
 
-    def test_pytorch_equal_numpy(self):
-        kwargs = dict(self.forward_default_kwargs)
-        num_inference_steps = kwargs.pop("num_inference_steps", None)
-
-        for scheduler_class in self.scheduler_classes:
-            sample_pt = self.dummy_sample
-            residual_pt = 0.1 * sample_pt
-            dummy_past_residuals_pt = [residual_pt + 0.2, residual_pt + 0.15, residual_pt + 0.1, residual_pt + 0.05]
-
-            sample = sample_pt.numpy()
-            residual = 0.1 * sample
-            dummy_past_residuals = [residual + 0.2, residual + 0.15, residual + 0.1, residual + 0.05]
-
-            scheduler_config = self.get_scheduler_config()
-            scheduler = scheduler_class(tensor_format="np", **scheduler_config)
-
-            scheduler_pt = scheduler_class(tensor_format="pt", **scheduler_config)
-
-            if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
-                scheduler.set_timesteps(num_inference_steps)
-                scheduler_pt.set_timesteps(num_inference_steps)
-            elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
-                kwargs["num_inference_steps"] = num_inference_steps
-
-            # copy over dummy past residuals (must be done after set_timesteps)
-            scheduler.ets = dummy_past_residuals[:]
-            scheduler_pt.ets = dummy_past_residuals_pt[:]
-
-            output = scheduler.step_prk(residual, 1, sample, **kwargs).prev_sample
-            output_pt = scheduler_pt.step_prk(residual_pt, 1, sample_pt, **kwargs).prev_sample
-            assert np.sum(np.abs(output - output_pt.numpy())) < 1e-4, "Scheduler outputs are not identical"
-
-            output = scheduler.step_plms(residual, 1, sample, **kwargs).prev_sample
-            output_pt = scheduler_pt.step_plms(residual_pt, 1, sample_pt, **kwargs).prev_sample
-
-            assert np.sum(np.abs(output - output_pt.numpy())) < 1e-4, "Scheduler outputs are not identical"
-
-    def test_set_format(self):
-        kwargs = dict(self.forward_default_kwargs)
-        num_inference_steps = kwargs.pop("num_inference_steps", None)
-
-        for scheduler_class in self.scheduler_classes:
-            scheduler_config = self.get_scheduler_config()
-            scheduler = scheduler_class(tensor_format="np", **scheduler_config)
-            scheduler_pt = scheduler_class(tensor_format="pt", **scheduler_config)
-
-            if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
-                scheduler.set_timesteps(num_inference_steps)
-                scheduler_pt.set_timesteps(num_inference_steps)
-
-            for key, value in vars(scheduler).items():
-                # we only allow `ets` attr to be a list
-                assert not isinstance(value, list) or key in [
-                    "ets"
-                ], f"Scheduler is not correctly set to np format, the attribute {key} is {type(value)}"
-
-            # check if `scheduler.set_format` does convert correctly attrs to pt format
-            for key, value in vars(scheduler_pt).items():
-                # we only allow `ets` attr to be a list
-                assert not isinstance(value, list) or key in [
-                    "ets"
-                ], f"Scheduler is not correctly set to pt format, the attribute {key} is {type(value)}"
-                assert not isinstance(
-                    value, np.ndarray
-                ), f"Scheduler is not correctly set to pt format, the attribute {key} is {type(value)}"
-
     def test_step_shape(self):
         kwargs = dict(self.forward_default_kwargs)
@@ -667,12 +568,10 @@ def test_steps_offset(self):
             scheduler_config = self.get_scheduler_config(steps_offset=1)
             scheduler = scheduler_class(**scheduler_config)
             scheduler.set_timesteps(10)
-            assert torch.equal(
+            assert np.equal(
                 scheduler.timesteps,
-                torch.tensor(
-                    [901, 851, 851, 801, 801, 751, 751, 701, 701, 651, 651, 601, 601, 501, 401, 301, 201, 101, 1]
-                ),
-            )
+                np.array([901, 851, 851, 801, 801, 751, 751, 701, 701, 651, 651, 601, 601, 501, 401, 301, 201, 101, 1]),
+            ).all()
 
     def test_betas(self):
         for beta_start, beta_end in zip([0.0001, 0.001], [0.002, 0.02]):
@@ -786,7 +685,6 @@ def get_scheduler_config(self, **kwargs):
             "sigma_min": 0.01,
             "sigma_max": 1348,
             "sampling_eps": 1e-5,
-            "tensor_format": "pt",  # TODO add test for tensor formats
         }
 
         config.update(**kwargs)
@@ -936,7 +834,6 @@ def get_scheduler_config(self, **kwargs):
             "beta_end": 0.02,
             "beta_schedule": "linear",
             "trained_betas": None,
-            "tensor_format": "pt",
         }
 
         config.update(**kwargs)
@@ -958,28 +855,6 @@ def test_time_indices(self):
         for t in [0, 500, 800]:
             self.check_over_forward(time_step=t)
 
-    def test_pytorch_equal_numpy(self):
-        for scheduler_class in self.scheduler_classes:
-            sample_pt = self.dummy_sample
-            residual_pt = 0.1 * sample_pt
-
-            sample = sample_pt.numpy()
-            residual = 0.1 * sample
-
-            scheduler_config = self.get_scheduler_config()
-            scheduler_config["tensor_format"] = "np"
-            scheduler = scheduler_class(**scheduler_config)
-
-            scheduler_config["tensor_format"] = "pt"
-            scheduler_pt = scheduler_class(**scheduler_config)
-
-            scheduler.set_timesteps(self.num_inference_steps)
-            scheduler_pt.set_timesteps(self.num_inference_steps)
-
-            output = scheduler.step(residual, 1, sample).prev_sample
-            output_pt = scheduler_pt.step(residual_pt, 1, sample_pt).prev_sample
-            assert np.sum(np.abs(output - output_pt.numpy())) < 1e-4, "Scheduler outputs are not identical"
-
     def test_full_loop_no_noise(self):
         scheduler_class = self.scheduler_classes[0]
         scheduler_config = self.get_scheduler_config()
@@ -1001,5 +876,5 @@ def test_full_loop_no_noise(self):
         result_sum = torch.sum(torch.abs(sample))
         result_mean = torch.mean(torch.abs(sample))
 
-        assert abs(result_sum.item() - 1006.388) < 1e-2
+        assert abs(result_sum.item() - 1006.370) < 1e-2
         assert abs(result_mean.item() - 1.31) < 1e-3
diff --git a/tests/test_training.py b/tests/test_training.py
index 519c5ab9e716..41aae07e33c6 100644
--- a/tests/test_training.py
+++ b/tests/test_training.py
@@ -41,7 +41,6 @@ def test_training_step_equality(self):
             beta_end=0.02,
             beta_schedule="linear",
             clip_sample=True,
-            tensor_format="pt",
         )
         ddim_scheduler = DDIMScheduler(
             num_train_timesteps=1000,
@@ -49,7 +48,6 @@ def test_training_step_equality(self):
             beta_end=0.02,
             beta_schedule="linear",
             clip_sample=True,
-            tensor_format="pt",
         )
 
         assert ddpm_scheduler.config.num_train_timesteps == ddim_scheduler.config.num_train_timesteps