From 479e0d2a308e337af2bdb3d858075d76b48ef0a1 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Fri, 16 Sep 2022 17:52:14 +0200 Subject: [PATCH 01/33] pytorch only schedulers --- src/diffusers/pipelines/ddim/pipeline_ddim.py | 1 - src/diffusers/schedulers/scheduling_ddim.py | 57 ++++++++----------- 2 files changed, 25 insertions(+), 33 deletions(-) diff --git a/src/diffusers/pipelines/ddim/pipeline_ddim.py b/src/diffusers/pipelines/ddim/pipeline_ddim.py index 95b49e045f67..d9cca7ee62ef 100644 --- a/src/diffusers/pipelines/ddim/pipeline_ddim.py +++ b/src/diffusers/pipelines/ddim/pipeline_ddim.py @@ -36,7 +36,6 @@ class DDIMPipeline(DiffusionPipeline): def __init__(self, unet, scheduler): super().__init__() - scheduler = scheduler.set_format("pt") self.register_modules(unet=unet, scheduler=scheduler) @torch.no_grad() diff --git a/src/diffusers/schedulers/scheduling_ddim.py b/src/diffusers/schedulers/scheduling_ddim.py index 85ba67523172..17f37c843b68 100644 --- a/src/diffusers/schedulers/scheduling_ddim.py +++ b/src/diffusers/schedulers/scheduling_ddim.py @@ -22,10 +22,10 @@ import torch from ..configuration_utils import ConfigMixin, register_to_config -from .scheduling_utils import SchedulerMixin, SchedulerOutput +from .scheduling_utils import SchedulerOutput, SCHEDULER_CONFIG_NAME -def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999): +def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> torch.Tensor: """ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of (1-beta) over time from t = [0,1]. @@ -51,10 +51,10 @@ def alpha_bar(time_step): t1 = i / num_diffusion_timesteps t2 = (i + 1) / num_diffusion_timesteps betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) - return np.array(betas, dtype=np.float32) + return torch.tensor(betas) -class DDIMScheduler(SchedulerMixin, ConfigMixin): +class DDIMScheduler(ConfigMixin): """ Denoising diffusion implicit models is a scheduler that extends the denoising procedure introduced in denoising diffusion probabilistic models (DDPMs) with non-Markovian guidance. @@ -79,10 +79,10 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin): option to clip predicted sample between -1 and 1 for numerical stability. set_alpha_to_one (`bool`, default `True`): if alpha for final step is 1 or the final alpha of the "non-previous" one. - tensor_format (`str`): whether the scheduler expects pytorch or numpy arrays. - """ + config_name = SCHEDULER_CONFIG_NAME + @register_to_config def __init__( self, @@ -93,15 +93,16 @@ def __init__( trained_betas: Optional[np.ndarray] = None, clip_sample: bool = True, set_alpha_to_one: bool = True, - tensor_format: str = "pt", ): if trained_betas is not None: - self.betas = np.asarray(trained_betas) + self.betas = torch.from_numpy(trained_betas) if beta_schedule == "linear": - self.betas = np.linspace(beta_start, beta_end, num_train_timesteps, dtype=np.float32) + self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) elif beta_schedule == "scaled_linear": # this schedule is very specific to the latent diffusion model. 
- self.betas = np.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=np.float32) ** 2 + self.betas = ( + torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 + ) elif beta_schedule == "squaredcos_cap_v2": # Glide cosine schedule self.betas = betas_for_alpha_bar(num_train_timesteps) @@ -109,20 +110,17 @@ def __init__( raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") self.alphas = 1.0 - self.betas - self.alphas_cumprod = np.cumprod(self.alphas, axis=0) + self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) # At every step in ddim, we are looking into the previous alphas_cumprod # For the final step, there is no previous alphas_cumprod because we are already at 0 # `set_alpha_to_one` decides whether we set this parameter simply to one or # whether we use the final alpha of the "non-previous" one. - self.final_alpha_cumprod = np.array(1.0) if set_alpha_to_one else self.alphas_cumprod[0] + self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0] # setable values self.num_inference_steps = None - self.timesteps = np.arange(0, num_train_timesteps)[::-1].copy() - - self.tensor_format = tensor_format - self.set_format(tensor_format=tensor_format) + self.timesteps = np.arange(0, num_train_timesteps)[::-1] def _get_variance(self, timestep, prev_timestep): alpha_prod_t = self.alphas_cumprod[timestep] @@ -148,15 +146,14 @@ def set_timesteps(self, num_inference_steps: int, offset: int = 0): step_ratio = self.config.num_train_timesteps // self.num_inference_steps # creates integer timesteps by multiplying by ratio # casting to int to avoid issues when num_inference_step is power of 3 - self.timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy() + self.timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1] self.timesteps += offset - self.set_format(tensor_format=self.tensor_format) def step( self, - model_output: Union[torch.FloatTensor, np.ndarray], + model_output: torch.FloatTensor, timestep: int, - sample: Union[torch.FloatTensor, np.ndarray], + sample: torch.FloatTensor, eta: float = 0.0, use_clipped_model_output: bool = False, generator=None, @@ -167,9 +164,9 @@ def step( process from the learned model outputs (most often the predicted noise). Args: - model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model. + model_output (`torch.FloatTensor`): direct output from learned diffusion model. timestep (`int`): current discrete timestep in the diffusion chain. - sample (`torch.FloatTensor` or `np.ndarray`): + sample (`torch.FloatTensor`): current instance of sample being created by diffusion process. eta (`float`): weight of noise for added noise in diffusion step. use_clipped_model_output (`bool`): TODO @@ -212,7 +209,7 @@ def step( # 4. Clip "predicted x_0" if self.config.clip_sample: - pred_original_sample = self.clip(pred_original_sample, -1, 1) + pred_original_sample = torch.clamp(pred_original_sample, -1, 1) # 5. 
compute variance: "sigma_t(η)" -> see formula (16) # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1) @@ -234,9 +231,6 @@ def step( noise = torch.randn(model_output.shape, generator=generator).to(device) variance = self._get_variance(timestep, prev_timestep) ** (0.5) * eta * noise - if not torch.is_tensor(model_output): - variance = variance.numpy() - prev_sample = prev_sample + variance if not return_dict: @@ -246,12 +240,11 @@ def step( def add_noise( self, - original_samples: Union[torch.FloatTensor, np.ndarray], - noise: Union[torch.FloatTensor, np.ndarray], - timesteps: Union[torch.IntTensor, np.ndarray], - ) -> Union[torch.FloatTensor, np.ndarray]: - if self.tensor_format == "pt": - timesteps = timesteps.to(self.alphas_cumprod.device) + original_samples: torch.FloatTensor, + noise: torch.FloatTensor, + timesteps: torch.FloatTensor, + ) -> torch.FloatTensor: + timesteps = timesteps.to(self.alphas_cumprod.device) sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5 sqrt_alpha_prod = self.match_shape(sqrt_alpha_prod, original_samples) sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5 From 3cc87967e52fee1ebcc7ec226dab84328d90f0bb Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Fri, 16 Sep 2022 17:55:41 +0200 Subject: [PATCH 02/33] fix style --- src/diffusers/schedulers/scheduling_ddim.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/schedulers/scheduling_ddim.py b/src/diffusers/schedulers/scheduling_ddim.py index 17f37c843b68..6811a16fc727 100644 --- a/src/diffusers/schedulers/scheduling_ddim.py +++ b/src/diffusers/schedulers/scheduling_ddim.py @@ -22,7 +22,7 @@ import torch from ..configuration_utils import ConfigMixin, register_to_config -from .scheduling_utils import SchedulerOutput, SCHEDULER_CONFIG_NAME +from .scheduling_utils import SCHEDULER_CONFIG_NAME, SchedulerOutput def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> torch.Tensor: From c1bb9577a8dd4b874d154358d1f33b191815f473 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Sun, 18 Sep 2022 18:34:04 +0200 Subject: [PATCH 03/33] remove match_shape --- src/diffusers/schedulers/scheduling_ddim.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/src/diffusers/schedulers/scheduling_ddim.py b/src/diffusers/schedulers/scheduling_ddim.py index a766af5fd099..cda17173e45a 100644 --- a/src/diffusers/schedulers/scheduling_ddim.py +++ b/src/diffusers/schedulers/scheduling_ddim.py @@ -265,9 +265,15 @@ def add_noise( ) -> torch.FloatTensor: timesteps = timesteps.to(self.alphas_cumprod.device) sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5 + sqrt_alpha_prod = sqrt_alpha_prod.flatten() + while len(sqrt_alpha_prod.shape) < len(original_samples.shape): + sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) sqrt_alpha_prod = self.match_shape(sqrt_alpha_prod, original_samples) + sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5 - sqrt_one_minus_alpha_prod = self.match_shape(sqrt_one_minus_alpha_prod, original_samples) + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() + while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape): + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise return noisy_samples From 6818e9a656a04133dcce17e8ea557b8b36103610 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Sun, 18 Sep 2022 18:36:20 +0200 Subject: [PATCH 04/33] pytorch only ddpm --- 
src/diffusers/schedulers/scheduling_ddpm.py | 54 +++++++++++---------- 1 file changed, 29 insertions(+), 25 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_ddpm.py b/src/diffusers/schedulers/scheduling_ddpm.py index fac75bc43eff..c89647c3b977 100644 --- a/src/diffusers/schedulers/scheduling_ddpm.py +++ b/src/diffusers/schedulers/scheduling_ddpm.py @@ -93,15 +93,16 @@ def __init__( trained_betas: Optional[np.ndarray] = None, variance_type: str = "fixed_small", clip_sample: bool = True, - tensor_format: str = "pt", ): if trained_betas is not None: - self.betas = np.asarray(trained_betas) + self.betas = torch.from_numpy(trained_betas) elif beta_schedule == "linear": - self.betas = np.linspace(beta_start, beta_end, num_train_timesteps, dtype=np.float32) + self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) elif beta_schedule == "scaled_linear": # this schedule is very specific to the latent diffusion model. - self.betas = np.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=np.float32) ** 2 + self.betas = ( + torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 + ) elif beta_schedule == "squaredcos_cap_v2": # Glide cosine schedule self.betas = betas_for_alpha_bar(num_train_timesteps) @@ -109,15 +110,12 @@ def __init__( raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") self.alphas = 1.0 - self.betas - self.alphas_cumprod = np.cumprod(self.alphas, axis=0) - self.one = np.array(1.0) + self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) + self.one = torch.tensor(1.0) # setable values self.num_inference_steps = None - self.timesteps = np.arange(0, num_train_timesteps)[::-1].copy() - - self.tensor_format = tensor_format - self.set_format(tensor_format=tensor_format) + self.timesteps = np.arange(0, num_train_timesteps)[::-1] self.variance_type = variance_type @@ -133,8 +131,7 @@ def set_timesteps(self, num_inference_steps: int): self.num_inference_steps = num_inference_steps self.timesteps = np.arange( 0, self.config.num_train_timesteps, self.config.num_train_timesteps // self.num_inference_steps - )[::-1].copy() - self.set_format(tensor_format=self.tensor_format) + )[::-1] def _get_variance(self, t, predicted_variance=None, variance_type=None): alpha_prod_t = self.alphas_cumprod[t] @@ -150,15 +147,15 @@ def _get_variance(self, t, predicted_variance=None, variance_type=None): # hacks - were probably added for training stability if variance_type == "fixed_small": - variance = self.clip(variance, min_value=1e-20) + variance = torch.clamp(variance, min_value=1e-20) # for rl-diffuser https://arxiv.org/abs/2205.09991 elif variance_type == "fixed_small_log": - variance = self.log(self.clip(variance, min_value=1e-20)) + variance = torch.log(torch.clamp(variance, min_value=1e-20)) elif variance_type == "fixed_large": variance = self.betas[t] elif variance_type == "fixed_large_log": # Glide max_log - variance = self.log(self.betas[t]) + variance = torch.log(self.betas[t]) elif variance_type == "learned": return predicted_variance elif variance_type == "learned_range": @@ -171,9 +168,9 @@ def _get_variance(self, t, predicted_variance=None, variance_type=None): def step( self, - model_output: Union[torch.FloatTensor, np.ndarray], + model_output: torch.FloatTensor, timestep: int, - sample: Union[torch.FloatTensor, np.ndarray], + sample: torch.FloatTensor, predict_epsilon=True, generator=None, return_dict: bool = True, @@ -183,9 +180,9 @@ def step( process 
from the learned model outputs (most often the predicted noise). Args: - model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model. + model_output (`torch.FloatTensor`): direct output from learned diffusion model. timestep (`int`): current discrete timestep in the diffusion chain. - sample (`torch.FloatTensor` or `np.ndarray`): + sample (`torch.FloatTensor`): current instance of sample being created by diffusion process. predict_epsilon (`bool`): optional flag to use when model predicts the samples directly instead of the noise, epsilon. @@ -220,7 +217,7 @@ def step( # 3. Clip "predicted x_0" if self.config.clip_sample: - pred_original_sample = self.clip(pred_original_sample, -1, 1) + pred_original_sample = torch.clamp(pred_original_sample, -1, 1) # 4. Compute coefficients for pred_original_sample x_0 and current sample x_t # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf @@ -234,7 +231,9 @@ def step( # 6. Add noise variance = 0 if t > 0: - noise = self.randn_like(model_output, generator=generator) + noise = torch.randn( + model_output.size(), dtype=model_output.dtype, layout=model_output.layout, generator=generator + ).to(model_output.device) variance = (self._get_variance(t, predicted_variance=predicted_variance) ** 0.5) * noise pred_prev_sample = pred_prev_sample + variance @@ -250,12 +249,17 @@ def add_noise( noise: Union[torch.FloatTensor, np.ndarray], timesteps: Union[torch.IntTensor, np.ndarray], ) -> Union[torch.FloatTensor, np.ndarray]: - if self.tensor_format == "pt": - timesteps = timesteps.to(self.alphas_cumprod.device) + timesteps = timesteps.to(self.alphas_cumprod.device) + sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5 - sqrt_alpha_prod = self.match_shape(sqrt_alpha_prod, original_samples) + sqrt_alpha_prod = sqrt_alpha_prod.flatten() + while len(sqrt_alpha_prod.shape) < len(original_samples.shape): + sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) + sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5 - sqrt_one_minus_alpha_prod = self.match_shape(sqrt_one_minus_alpha_prod, original_samples) + sqrt_one_minus_alpha_prod.flatten() + while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape): + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise return noisy_samples From 66d48b80984ad7ed3432a58ef9419aca335cab7f Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Sun, 18 Sep 2022 18:37:28 +0200 Subject: [PATCH 05/33] remove SchedulerMixin --- src/diffusers/schedulers/scheduling_ddpm.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_ddpm.py b/src/diffusers/schedulers/scheduling_ddpm.py index c89647c3b977..ba24f19309d0 100644 --- a/src/diffusers/schedulers/scheduling_ddpm.py +++ b/src/diffusers/schedulers/scheduling_ddpm.py @@ -21,7 +21,7 @@ import torch from ..configuration_utils import ConfigMixin, register_to_config -from .scheduling_utils import SchedulerMixin, SchedulerOutput +from .scheduling_utils import SchedulerOutput def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999): @@ -53,7 +53,7 @@ def alpha_bar(time_step): return np.array(betas, dtype=np.float32) -class DDPMScheduler(SchedulerMixin, ConfigMixin): +class DDPMScheduler(ConfigMixin): """ Denoising diffusion probabilistic models (DDPMs) explores the connections between denoising score matching and Langevin dynamics sampling. 
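The hunks above replace the old `match_shape` helper with an explicit broadcast: the per-timestep coefficients gathered from `alphas_cumprod` are flattened to a 1-D batch vector and then given trailing singleton dimensions until their rank matches the sample, so plain multiplication broadcasts over the channel and spatial axes. A minimal PyTorch sketch of that pattern follows; the tensor shapes and variable names are illustrative only and are not taken from these patches.

import torch

def broadcast_coeff(coeff: torch.Tensor, sample: torch.Tensor) -> torch.Tensor:
    # Tensor.flatten() is not in-place, so the result must be reassigned.
    coeff = coeff.flatten()
    # Append trailing singleton dims until the rank matches the sample,
    # e.g. (B,) -> (B, 1, 1, 1) for a (B, C, H, W) sample.
    while coeff.dim() < sample.dim():
        coeff = coeff.unsqueeze(-1)
    return coeff

# Illustrative forward-noising step q(x_t | x_0) using the broadcast pattern.
alphas_cumprod = torch.linspace(0.9999, 0.0001, 1000)
timesteps = torch.randint(0, 1000, (4,))
original_samples = torch.randn(4, 3, 32, 32)
noise = torch.randn_like(original_samples)

sqrt_alpha_prod = broadcast_coeff(alphas_cumprod[timesteps] ** 0.5, original_samples)
sqrt_one_minus_alpha_prod = broadcast_coeff((1 - alphas_cumprod[timesteps]) ** 0.5, original_samples)
noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise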
From f656c69781bf0234478c8a16f8705d19b269ecdb Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Sun, 18 Sep 2022 18:48:41 +0200 Subject: [PATCH 06/33] remove numpy from karras_ve --- .../schedulers/scheduling_karras_ve.py | 48 ++++++++----------- 1 file changed, 20 insertions(+), 28 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_karras_ve.py b/src/diffusers/schedulers/scheduling_karras_ve.py index caf7625fb683..35a9ca49ce35 100644 --- a/src/diffusers/schedulers/scheduling_karras_ve.py +++ b/src/diffusers/schedulers/scheduling_karras_ve.py @@ -21,7 +21,6 @@ from ..configuration_utils import ConfigMixin, register_to_config from ..utils import BaseOutput -from .scheduling_utils import SchedulerMixin @dataclass @@ -41,7 +40,7 @@ class KarrasVeOutput(BaseOutput): derivative: torch.FloatTensor -class KarrasVeScheduler(SchedulerMixin, ConfigMixin): +class KarrasVeScheduler(ConfigMixin): """ Stochastic sampling from Karras et al. [1] tailored to the Variance-Expanding (VE) models [2]. Use Algorithm 2 and the VE column of Table 1 from [1] for reference. @@ -70,7 +69,6 @@ class KarrasVeScheduler(SchedulerMixin, ConfigMixin): A reasonable range is [0, 10]. s_max (`float`): the end value of the sigma range where we add noise. A reasonable range is [0.2, 80]. - tensor_format (`str`): whether the scheduler expects pytorch or numpy arrays. """ @@ -83,15 +81,11 @@ def __init__( s_churn: float = 80, s_min: float = 0.05, s_max: float = 50, - tensor_format: str = "pt", ): # setable values - self.num_inference_steps = None - self.timesteps = None - self.schedule = None # sigma(t_i) - - self.tensor_format = tensor_format - self.set_format(tensor_format=tensor_format) + self.num_inference_steps: int = None + self.timesteps: np.ndarray = None + self.schedule: torch.FloatTensor = None # sigma(t_i) def set_timesteps(self, num_inference_steps: int): """ @@ -104,20 +98,18 @@ def set_timesteps(self, num_inference_steps: int): """ self.num_inference_steps = num_inference_steps self.timesteps = np.arange(0, self.num_inference_steps)[::-1].copy() - self.schedule = [ + schedule = [ ( self.config.sigma_max * (self.config.sigma_min**2 / self.config.sigma_max**2) ** (i / (num_inference_steps - 1)) ) for i in self.timesteps ] - self.schedule = np.array(self.schedule, dtype=np.float32) - - self.set_format(tensor_format=self.tensor_format) + self.schedule = torch.tensor(schedule, dtype=torch.float32) def add_noise_to_input( - self, sample: Union[torch.FloatTensor, np.ndarray], sigma: float, generator: Optional[torch.Generator] = None - ) -> Tuple[Union[torch.FloatTensor, np.ndarray], float]: + self, sample: torch.FloatTensor, sigma: float, generator: Optional[torch.Generator] = None + ) -> Tuple[torch.FloatTensor, float]: """ Explicit Langevin-like "churn" step of adding noise to the sample according to a factor gamma_i ≥ 0 to reach a higher noise level sigma_hat = sigma_i + gamma_i*sigma_i. @@ -138,10 +130,10 @@ def add_noise_to_input( def step( self, - model_output: Union[torch.FloatTensor, np.ndarray], + model_output: torch.FloatTensor, sigma_hat: float, sigma_prev: float, - sample_hat: Union[torch.FloatTensor, np.ndarray], + sample_hat: torch.FloatTensor, return_dict: bool = True, ) -> Union[KarrasVeOutput, Tuple]: """ @@ -149,10 +141,10 @@ def step( process from the learned model outputs (most often the predicted noise). Args: - model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model. + model_output (`torch.FloatTensor`): direct output from learned diffusion model. 
sigma_hat (`float`): TODO sigma_prev (`float`): TODO - sample_hat (`torch.FloatTensor` or `np.ndarray`): TODO + sample_hat (`torch.FloatTensor`): TODO return_dict (`bool`): option for returning tuple rather than SchedulerOutput class KarrasVeOutput: updated sample in the diffusion chain and derivative (TODO double check). @@ -174,24 +166,24 @@ def step( def step_correct( self, - model_output: Union[torch.FloatTensor, np.ndarray], + model_output: torch.FloatTensor, sigma_hat: float, sigma_prev: float, - sample_hat: Union[torch.FloatTensor, np.ndarray], - sample_prev: Union[torch.FloatTensor, np.ndarray], - derivative: Union[torch.FloatTensor, np.ndarray], + sample_hat: torch.FloatTensor, + sample_prev: torch.FloatTensor, + derivative: torch.FloatTensor, return_dict: bool = True, ) -> Union[KarrasVeOutput, Tuple]: """ Correct the predicted sample based on the output model_output of the network. TODO complete description Args: - model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model. + model_output (`torch.FloatTensor`): direct output from learned diffusion model. sigma_hat (`float`): TODO sigma_prev (`float`): TODO - sample_hat (`torch.FloatTensor` or `np.ndarray`): TODO - sample_prev (`torch.FloatTensor` or `np.ndarray`): TODO - derivative (`torch.FloatTensor` or `np.ndarray`): TODO + sample_hat (`torch.FloatTensor`): TODO + sample_prev (`torch.FloatTensor`): TODO + derivative (`torch.FloatTensor`): TODO return_dict (`bool`): option for returning tuple rather than SchedulerOutput class Returns: From 66febf13979e38f47126ded9308b08f588eea1ca Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Sun, 18 Sep 2022 19:00:33 +0200 Subject: [PATCH 07/33] fix types --- src/diffusers/schedulers/scheduling_ddim.py | 2 +- src/diffusers/schedulers/scheduling_ddpm.py | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_ddim.py b/src/diffusers/schedulers/scheduling_ddim.py index cda17173e45a..a51b4173dfe0 100644 --- a/src/diffusers/schedulers/scheduling_ddim.py +++ b/src/diffusers/schedulers/scheduling_ddim.py @@ -261,7 +261,7 @@ def add_noise( self, original_samples: torch.FloatTensor, noise: torch.FloatTensor, - timesteps: torch.FloatTensor, + timesteps: torch.IntTensor, ) -> torch.FloatTensor: timesteps = timesteps.to(self.alphas_cumprod.device) sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5 diff --git a/src/diffusers/schedulers/scheduling_ddpm.py b/src/diffusers/schedulers/scheduling_ddpm.py index ba24f19309d0..cc86d2cd1a21 100644 --- a/src/diffusers/schedulers/scheduling_ddpm.py +++ b/src/diffusers/schedulers/scheduling_ddpm.py @@ -245,10 +245,10 @@ def step( def add_noise( self, - original_samples: Union[torch.FloatTensor, np.ndarray], - noise: Union[torch.FloatTensor, np.ndarray], - timesteps: Union[torch.IntTensor, np.ndarray], - ) -> Union[torch.FloatTensor, np.ndarray]: + original_samples: torch.FloatTensor, + noise: torch.FloatTensor, + timesteps: torch.IntTensor, + ) -> torch.FloatTensor: timesteps = timesteps.to(self.alphas_cumprod.device) sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5 From d6cce91f55de5a6779ce13a8f1233971c112dbbb Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Sun, 18 Sep 2022 19:00:53 +0200 Subject: [PATCH 08/33] remove numpy from lms_discrete --- .../schedulers/scheduling_lms_discrete.py | 43 ++++++++++--------- 1 file changed, 22 insertions(+), 21 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_lms_discrete.py 
b/src/diffusers/schedulers/scheduling_lms_discrete.py index 5857ae70a856..fc8db1af3e3d 100644 --- a/src/diffusers/schedulers/scheduling_lms_discrete.py +++ b/src/diffusers/schedulers/scheduling_lms_discrete.py @@ -20,10 +20,10 @@ from scipy import integrate from ..configuration_utils import ConfigMixin, register_to_config -from .scheduling_utils import SchedulerMixin, SchedulerOutput +from .scheduling_utils import SchedulerOutput -class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin): +class LMSDiscreteScheduler(ConfigMixin): """ Linear Multistep Scheduler for discrete beta schedules. Based on the original k-diffusion implementation by Katherine Crowson: @@ -60,17 +60,19 @@ def __init__( tensor_format: str = "pt", ): if trained_betas is not None: - self.betas = np.asarray(trained_betas) + self.betas = torch.from_numpy(trained_betas) if beta_schedule == "linear": - self.betas = np.linspace(beta_start, beta_end, num_train_timesteps, dtype=np.float32) + self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) elif beta_schedule == "scaled_linear": # this schedule is very specific to the latent diffusion model. - self.betas = np.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=np.float32) ** 2 + self.betas = ( + torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 + ) else: raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") self.alphas = 1.0 - self.betas - self.alphas_cumprod = np.cumprod(self.alphas, axis=0) + self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) self.sigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5 @@ -79,9 +81,6 @@ def __init__( self.timesteps = np.arange(0, num_train_timesteps)[::-1].copy() self.derivatives = [] - self.tensor_format = tensor_format - self.set_format(tensor_format=tensor_format) - def get_lms_coefficient(self, order, t, current_order): """ Compute a linear multistep coefficient. @@ -128,9 +127,9 @@ def set_timesteps(self, num_inference_steps: int): def step( self, - model_output: Union[torch.FloatTensor, np.ndarray], + model_output: torch.FloatTensor, timestep: int, - sample: Union[torch.FloatTensor, np.ndarray], + sample: torch.FloatTensor, order: int = 4, return_dict: bool = True, ) -> Union[SchedulerOutput, Tuple]: @@ -139,9 +138,9 @@ def step( process from the learned model outputs (most often the predicted noise). Args: - model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model. + model_output (`torch.FloatTensor`): direct output from learned diffusion model. timestep (`int`): current discrete timestep in the diffusion chain. - sample (`torch.FloatTensor` or `np.ndarray`): + sample (`torch.FloatTensor`): current instance of sample being created by diffusion process. order: coefficient for multi-step inference. 
return_dict (`bool`): option for returning tuple rather than SchedulerOutput class @@ -179,15 +178,17 @@ def step( def add_noise( self, - original_samples: Union[torch.FloatTensor, np.ndarray], - noise: Union[torch.FloatTensor, np.ndarray], - timesteps: Union[torch.IntTensor, np.ndarray], - ) -> Union[torch.FloatTensor, np.ndarray]: - if self.tensor_format == "pt": - timesteps = timesteps.to(self.sigmas.device) - sigmas = self.match_shape(self.sigmas[timesteps], noise) - noisy_samples = original_samples + noise * sigmas + original_samples: torch.FloatTensor, + noise: torch.FloatTensor, + timesteps: torch.IntTensor, + ) -> torch.FloatTensor: + timesteps = timesteps.to(self.sigmas.device) + + sigma = self.sigmas[timesteps].flatten() + while len(sigma.shape) < len(original_samples.shape): + sigma = sigma.unsqueeze(-1) + noisy_samples = original_samples + noise * sigma return noisy_samples def __len__(self): From 1bb1716e948a994e333b3678949e89bc2435908f Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Sun, 18 Sep 2022 20:18:12 +0200 Subject: [PATCH 09/33] remove numpy from pndm --- src/diffusers/schedulers/scheduling_pndm.py | 67 +++++++++++---------- 1 file changed, 34 insertions(+), 33 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_pndm.py b/src/diffusers/schedulers/scheduling_pndm.py index 09e8a7e240c2..00e531082ab3 100644 --- a/src/diffusers/schedulers/scheduling_pndm.py +++ b/src/diffusers/schedulers/scheduling_pndm.py @@ -22,7 +22,7 @@ import torch from ..configuration_utils import ConfigMixin, register_to_config -from .scheduling_utils import SchedulerMixin, SchedulerOutput +from .scheduling_utils import SchedulerOutput def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999): @@ -51,10 +51,10 @@ def alpha_bar(time_step): t1 = i / num_diffusion_timesteps t2 = (i + 1) / num_diffusion_timesteps betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) - return np.array(betas, dtype=np.float32) + return torch.tensor(betas, dtype=torch.float32) -class PNDMScheduler(SchedulerMixin, ConfigMixin): +class PNDMScheduler(ConfigMixin): """ Pseudo numerical methods for diffusion models (PNDM) proposes using more advanced ODE integration techniques, namely Runge-Kutta method and a linear multi-step method. @@ -86,7 +86,6 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin): an offset added to the inference steps. You can use a combination of `offset=1` and `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in stable diffusion. - tensor_format (`str`): whether the scheduler expects pytorch or numpy arrays """ @@ -101,15 +100,16 @@ def __init__( skip_prk_steps: bool = False, set_alpha_to_one: bool = False, steps_offset: int = 0, - tensor_format: str = "pt", ): if trained_betas is not None: - self.betas = np.asarray(trained_betas) + self.betas = torch.from_numpy(trained_betas) if beta_schedule == "linear": - self.betas = np.linspace(beta_start, beta_end, num_train_timesteps, dtype=np.float32) + self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) elif beta_schedule == "scaled_linear": # this schedule is very specific to the latent diffusion model. 
- self.betas = np.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=np.float32) ** 2 + self.betas = ( + torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 + ) elif beta_schedule == "squaredcos_cap_v2": # Glide cosine schedule self.betas = betas_for_alpha_bar(num_train_timesteps) @@ -117,9 +117,9 @@ def __init__( raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") self.alphas = 1.0 - self.betas - self.alphas_cumprod = np.cumprod(self.alphas, axis=0) + self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) - self.final_alpha_cumprod = np.array(1.0) if set_alpha_to_one else self.alphas_cumprod[0] + self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0] # For now we only support F-PNDM, i.e. the runge-kutta method # For more information on the algorithm please take a look at the paper: https://arxiv.org/pdf/2202.09778.pdf @@ -139,9 +139,6 @@ def __init__( self.plms_timesteps = None self.timesteps = None - self.tensor_format = tensor_format - self.set_format(tensor_format=tensor_format) - def set_timesteps(self, num_inference_steps: int, **kwargs) -> torch.FloatTensor: """ Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference. @@ -189,13 +186,12 @@ def set_timesteps(self, num_inference_steps: int, **kwargs) -> torch.FloatTensor self.ets = [] self.counter = 0 - self.set_format(tensor_format=self.tensor_format) def step( self, - model_output: Union[torch.FloatTensor, np.ndarray], + model_output: torch.FloatTensor, timestep: int, - sample: Union[torch.FloatTensor, np.ndarray], + sample: torch.FloatTensor, return_dict: bool = True, ) -> Union[SchedulerOutput, Tuple]: """ @@ -205,9 +201,9 @@ def step( This function calls `step_prk()` or `step_plms()` depending on the internal variable `counter`. Args: - model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model. + model_output (`torch.FloatTensor`): direct output from learned diffusion model. timestep (`int`): current discrete timestep in the diffusion chain. - sample (`torch.FloatTensor` or `np.ndarray`): + sample (`torch.FloatTensor`): current instance of sample being created by diffusion process. return_dict (`bool`): option for returning tuple rather than SchedulerOutput class @@ -224,9 +220,9 @@ def step( def step_prk( self, - model_output: Union[torch.FloatTensor, np.ndarray], + model_output: torch.FloatTensor, timestep: int, - sample: Union[torch.FloatTensor, np.ndarray], + sample: torch.FloatTensor, return_dict: bool = True, ) -> Union[SchedulerOutput, Tuple]: """ @@ -234,9 +230,9 @@ def step_prk( solution to the differential equation. Args: - model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model. + model_output (`torch.FloatTensor`): direct output from learned diffusion model. timestep (`int`): current discrete timestep in the diffusion chain. - sample (`torch.FloatTensor` or `np.ndarray`): + sample (`torch.FloatTensor`): current instance of sample being created by diffusion process. 
return_dict (`bool`): option for returning tuple rather than SchedulerOutput class @@ -279,9 +275,9 @@ def step_prk( def step_plms( self, - model_output: Union[torch.FloatTensor, np.ndarray], + model_output: torch.FloatTensor, timestep: int, - sample: Union[torch.FloatTensor, np.ndarray], + sample: torch.FloatTensor, return_dict: bool = True, ) -> Union[SchedulerOutput, Tuple]: """ @@ -289,9 +285,9 @@ def step_plms( times to approximate the solution. Args: - model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model. + model_output (`torch.FloatTensor`): direct output from learned diffusion model. timestep (`int`): current discrete timestep in the diffusion chain. - sample (`torch.FloatTensor` or `np.ndarray`): + sample (`torch.FloatTensor`): current instance of sample being created by diffusion process. return_dict (`bool`): option for returning tuple rather than SchedulerOutput class @@ -381,16 +377,21 @@ def _get_prev_sample(self, sample, timestep, prev_timestep, model_output): def add_noise( self, - original_samples: Union[torch.FloatTensor, np.ndarray], - noise: Union[torch.FloatTensor, np.ndarray], - timesteps: Union[torch.IntTensor, np.ndarray], + original_samples: torch.FloatTensor, + noise: torch.FloatTensor, + timesteps: torch.FloatTensor, ) -> torch.Tensor: - if self.tensor_format == "pt": - timesteps = timesteps.to(self.alphas_cumprod.device) + timesteps = timesteps.to(self.alphas_cumprod.device) + sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5 - sqrt_alpha_prod = self.match_shape(sqrt_alpha_prod, original_samples) + sqrt_alpha_prod = sqrt_alpha_prod.flatten() + while len(sqrt_alpha_prod.shape) < len(original_samples.shape): + sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) + sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5 - sqrt_one_minus_alpha_prod = self.match_shape(sqrt_one_minus_alpha_prod, original_samples) + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() + while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape): + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise return noisy_samples From bde8899e9e90e14cfaeae6176221676c76143894 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Sun, 18 Sep 2022 20:46:35 +0200 Subject: [PATCH 10/33] fix typo --- src/diffusers/schedulers/scheduling_ddim.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/diffusers/schedulers/scheduling_ddim.py b/src/diffusers/schedulers/scheduling_ddim.py index a51b4173dfe0..a171fe7ff3c2 100644 --- a/src/diffusers/schedulers/scheduling_ddim.py +++ b/src/diffusers/schedulers/scheduling_ddim.py @@ -268,7 +268,6 @@ def add_noise( sqrt_alpha_prod = sqrt_alpha_prod.flatten() while len(sqrt_alpha_prod.shape) < len(original_samples.shape): sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) - sqrt_alpha_prod = self.match_shape(sqrt_alpha_prod, original_samples) sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5 sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() From ceb0b8aee98a06b22b3159ef64c83709229d6a11 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Sun, 18 Sep 2022 20:47:01 +0200 Subject: [PATCH 11/33] remove mixin and numpy from sde_vp and ve --- src/diffusers/schedulers/scheduling_sde_ve.py | 80 ++++++------------- src/diffusers/schedulers/scheduling_sde_vp.py | 3 +- 2 files changed, 27 insertions(+), 56 deletions(-) diff --git 
a/src/diffusers/schedulers/scheduling_sde_ve.py b/src/diffusers/schedulers/scheduling_sde_ve.py index 4af8f4fdad7d..0fd022a752ed 100644 --- a/src/diffusers/schedulers/scheduling_sde_ve.py +++ b/src/diffusers/schedulers/scheduling_sde_ve.py @@ -23,7 +23,7 @@ from ..configuration_utils import ConfigMixin, register_to_config from ..utils import BaseOutput -from .scheduling_utils import SchedulerMixin, SchedulerOutput +from .scheduling_utils import SchedulerOutput @dataclass @@ -43,7 +43,7 @@ class SdeVeOutput(BaseOutput): prev_sample_mean: torch.FloatTensor -class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin): +class ScoreSdeVeScheduler(ConfigMixin): """ The variance exploding stochastic differential equation (SDE) scheduler. @@ -65,7 +65,6 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin): sampling_eps (`float`): the end value of sampling, where timesteps decrease progressively from 1 to epsilon. correct_steps (`int`): number of correction steps performed on a produced sample. - tensor_format (`str`): "np" or "pt" for the expected format of samples passed to the Scheduler. """ @register_to_config @@ -77,16 +76,12 @@ def __init__( sigma_max: float = 1348.0, sampling_eps: float = 1e-5, correct_steps: int = 1, - tensor_format: str = "pt", ): # setable values self.timesteps = None self.set_sigmas(num_train_timesteps, sigma_min, sigma_max, sampling_eps) - self.tensor_format = tensor_format - self.set_format(tensor_format=tensor_format) - def set_timesteps(self, num_inference_steps: int, sampling_eps: float = None): """ Sets the continuous timesteps used for the diffusion chain. Supporting function to be run before inference. @@ -98,13 +93,8 @@ def set_timesteps(self, num_inference_steps: int, sampling_eps: float = None): """ sampling_eps = sampling_eps if sampling_eps is not None else self.config.sampling_eps - tensor_format = getattr(self, "tensor_format", "pt") - if tensor_format == "np": - self.timesteps = np.linspace(1, sampling_eps, num_inference_steps) - elif tensor_format == "pt": - self.timesteps = torch.linspace(1, sampling_eps, num_inference_steps) - else: - raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.") + + self.timesteps = torch.linspace(1, sampling_eps, num_inference_steps) def set_sigmas( self, num_inference_steps: int, sigma_min: float = None, sigma_max: float = None, sampling_eps: float = None @@ -129,28 +119,16 @@ def set_sigmas( if self.timesteps is None: self.set_timesteps(num_inference_steps, sampling_eps) - tensor_format = getattr(self, "tensor_format", "pt") - if tensor_format == "np": - self.discrete_sigmas = np.exp(np.linspace(np.log(sigma_min), np.log(sigma_max), num_inference_steps)) - self.sigmas = np.array([sigma_min * (sigma_max / sigma_min) ** t for t in self.timesteps]) - elif tensor_format == "pt": - self.discrete_sigmas = torch.exp(torch.linspace(np.log(sigma_min), np.log(sigma_max), num_inference_steps)) - self.sigmas = torch.tensor([sigma_min * (sigma_max / sigma_min) ** t for t in self.timesteps]) - else: - raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.") + self.sigmas = sigma_min * (sigma_max / sigma_min) ** (self.timesteps / sampling_eps) + self.discrete_sigmas = torch.exp(torch.linspace(np.log(sigma_min), np.log(sigma_max), num_inference_steps)) + self.sigmas = torch.tensor([sigma_min * (sigma_max / sigma_min) ** t for t in self.timesteps]) def get_adjacent_sigma(self, timesteps, t): - tensor_format = getattr(self, "tensor_format", "pt") - if tensor_format == "np": - return 
np.where(timesteps == 0, np.zeros_like(t), self.discrete_sigmas[timesteps - 1]) - elif tensor_format == "pt": - return torch.where( - timesteps == 0, - torch.zeros_like(t.to(timesteps.device)), - self.discrete_sigmas[timesteps - 1].to(timesteps.device), - ) - - raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.") + return torch.where( + timesteps == 0, + torch.zeros_like(t.to(timesteps.device)), + self.discrete_sigmas[timesteps - 1].to(timesteps.device), + ) def set_seed(self, seed): warnings.warn( @@ -158,19 +136,13 @@ def set_seed(self, seed): " generator instead.", DeprecationWarning, ) - tensor_format = getattr(self, "tensor_format", "pt") - if tensor_format == "np": - np.random.seed(seed) - elif tensor_format == "pt": - torch.manual_seed(seed) - else: - raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.") + torch.manual_seed(seed) def step_pred( self, - model_output: Union[torch.FloatTensor, np.ndarray], + model_output: torch.FloatTensor, timestep: int, - sample: Union[torch.FloatTensor, np.ndarray], + sample: torch.FloatTensor, generator: Optional[torch.Generator] = None, return_dict: bool = True, **kwargs, @@ -180,9 +152,9 @@ def step_pred( process from the learned model outputs (most often the predicted noise). Args: - model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model. + model_output (`torch.FloatTensor`): direct output from learned diffusion model. timestep (`int`): current discrete timestep in the diffusion chain. - sample (`torch.FloatTensor` or `np.ndarray`): + sample (`torch.FloatTensor`): current instance of sample being created by diffusion process. generator: random number generator. return_dict (`bool`): option for returning tuple rather than SchedulerOutput class @@ -210,7 +182,7 @@ def step_pred( sigma = self.discrete_sigmas[timesteps].to(sample.device) adjacent_sigma = self.get_adjacent_sigma(timesteps, timestep).to(sample.device) - drift = self.zeros_like(sample) + drift = torch.zeros_like(sample) diffusion = (sigma**2 - adjacent_sigma**2) ** 0.5 # equation 6 in the paper: the model_output modeled by the network is grad_x log pt(x) @@ -218,7 +190,7 @@ def step_pred( drift = drift - diffusion[:, None, None, None] ** 2 * model_output # equation 6: sample noise for the diffusion term of - noise = self.randn_like(sample, generator=generator) + noise = torch.randn_like(sample, layout=sample.layout, generator=generator).to(sample.device) prev_sample_mean = sample - drift # subtract because `dt` is a small negative timestep # TODO is the variable diffusion the correct scaling term for the noise? prev_sample = prev_sample_mean + diffusion[:, None, None, None] * noise # add impact of diffusion field g @@ -230,8 +202,8 @@ def step_pred( def step_correct( self, - model_output: Union[torch.FloatTensor, np.ndarray], - sample: Union[torch.FloatTensor, np.ndarray], + model_output: torch.FloatTensor, + sample: torch.FloatTensor, generator: Optional[torch.Generator] = None, return_dict: bool = True, **kwargs, @@ -241,8 +213,8 @@ def step_correct( after making the prediction for the previous timestep. Args: - model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model. - sample (`torch.FloatTensor` or `np.ndarray`): + model_output (`torch.FloatTensor`): direct output from learned diffusion model. + sample (`torch.FloatTensor`): current instance of sample being created by diffusion process. generator: random number generator. 
return_dict (`bool`): option for returning tuple rather than SchedulerOutput class @@ -262,11 +234,11 @@ def step_correct( # For small batch sizes, the paper "suggest replacing norm(z) with sqrt(d), where d is the dim. of z" # sample noise for correction - noise = self.randn_like(sample, generator=generator) + noise = torch.randn_like(sample, layout=sample.layout, generator=generator).to(sample.device) # compute step size from the model_output, the noise, and the snr - grad_norm = self.norm(model_output) - noise_norm = self.norm(noise) + grad_norm = torch.norm(model_output.reshape(model_output.shape[0], -1), dim=-1).mean() + noise_norm = torch.norm(noise.reshape(noise.shape[0], -1), dim=-1).mean() step_size = (self.config.snr * noise_norm / grad_norm) ** 2 * 2 step_size = step_size * torch.ones(sample.shape[0]).to(sample.device) # self.repeat_scalar(step_size, sample.shape[0]) diff --git a/src/diffusers/schedulers/scheduling_sde_vp.py b/src/diffusers/schedulers/scheduling_sde_vp.py index f19a5ad76f81..f79c753d6417 100644 --- a/src/diffusers/schedulers/scheduling_sde_vp.py +++ b/src/diffusers/schedulers/scheduling_sde_vp.py @@ -20,10 +20,9 @@ import torch from ..configuration_utils import ConfigMixin, register_to_config -from .scheduling_utils import SchedulerMixin -class ScoreSdeVpScheduler(SchedulerMixin, ConfigMixin): +class ScoreSdeVpScheduler(ConfigMixin): """ The variance preserving stochastic differential equation (SDE) scheduler. From 7aa375f33b0d1bea161bbbd4289c19da1be88578 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Sun, 18 Sep 2022 21:15:42 +0200 Subject: [PATCH 12/33] remove remaining tensor_format --- src/diffusers/schedulers/scheduling_ddpm.py | 11 ++++++----- src/diffusers/schedulers/scheduling_karras_ve.py | 3 +++ src/diffusers/schedulers/scheduling_lms_discrete.py | 8 +++----- src/diffusers/schedulers/scheduling_pndm.py | 4 +++- src/diffusers/schedulers/scheduling_sde_ve.py | 8 +++++--- src/diffusers/schedulers/scheduling_sde_vp.py | 5 ++++- 6 files changed, 24 insertions(+), 15 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_ddpm.py b/src/diffusers/schedulers/scheduling_ddpm.py index cc86d2cd1a21..5367dcc1a067 100644 --- a/src/diffusers/schedulers/scheduling_ddpm.py +++ b/src/diffusers/schedulers/scheduling_ddpm.py @@ -21,7 +21,7 @@ import torch from ..configuration_utils import ConfigMixin, register_to_config -from .scheduling_utils import SchedulerOutput +from .scheduling_utils import SCHEDULER_CONFIG_NAME, SchedulerOutput def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999): @@ -50,7 +50,7 @@ def alpha_bar(time_step): t1 = i / num_diffusion_timesteps t2 = (i + 1) / num_diffusion_timesteps betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) - return np.array(betas, dtype=np.float32) + return torch.tensor(betas, dtype=torch.float32) class DDPMScheduler(ConfigMixin): @@ -79,10 +79,11 @@ class DDPMScheduler(ConfigMixin): `fixed_small_log`, `fixed_large`, `fixed_large_log`, `learned` or `learned_range`. clip_sample (`bool`, default `True`): option to clip predicted sample between -1 and 1 for numerical stability. - tensor_format (`str`): whether the scheduler expects pytorch or numpy arrays. 
""" + config_name = SCHEDULER_CONFIG_NAME + @register_to_config def __init__( self, @@ -147,10 +148,10 @@ def _get_variance(self, t, predicted_variance=None, variance_type=None): # hacks - were probably added for training stability if variance_type == "fixed_small": - variance = torch.clamp(variance, min_value=1e-20) + variance = torch.clamp(variance, min=1e-20) # for rl-diffuser https://arxiv.org/abs/2205.09991 elif variance_type == "fixed_small_log": - variance = torch.log(torch.clamp(variance, min_value=1e-20)) + variance = torch.log(torch.clamp(variance, min=1e-20)) elif variance_type == "fixed_large": variance = self.betas[t] elif variance_type == "fixed_large_log": diff --git a/src/diffusers/schedulers/scheduling_karras_ve.py b/src/diffusers/schedulers/scheduling_karras_ve.py index 35a9ca49ce35..31a074b330cd 100644 --- a/src/diffusers/schedulers/scheduling_karras_ve.py +++ b/src/diffusers/schedulers/scheduling_karras_ve.py @@ -19,6 +19,7 @@ import numpy as np import torch +from .scheduling_utils import SCHEDULER_CONFIG_NAME from ..configuration_utils import ConfigMixin, register_to_config from ..utils import BaseOutput @@ -72,6 +73,8 @@ class KarrasVeScheduler(ConfigMixin): """ + config_name = SCHEDULER_CONFIG_NAME + @register_to_config def __init__( self, diff --git a/src/diffusers/schedulers/scheduling_lms_discrete.py b/src/diffusers/schedulers/scheduling_lms_discrete.py index fc8db1af3e3d..f8d14da5085d 100644 --- a/src/diffusers/schedulers/scheduling_lms_discrete.py +++ b/src/diffusers/schedulers/scheduling_lms_discrete.py @@ -20,7 +20,7 @@ from scipy import integrate from ..configuration_utils import ConfigMixin, register_to_config -from .scheduling_utils import SchedulerOutput +from .scheduling_utils import SCHEDULER_CONFIG_NAME, SchedulerOutput class LMSDiscreteScheduler(ConfigMixin): @@ -45,10 +45,11 @@ class LMSDiscreteScheduler(ConfigMixin): option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc. options to clip the variance used when adding noise to the denoised sample. Choose from `fixed_small`, `fixed_small_log`, `fixed_large`, `fixed_large_log`, `learned` or `learned_range`. - tensor_format (`str`): whether the scheduler expects pytorch or numpy arrays. 
""" + config_name = SCHEDULER_CONFIG_NAME + @register_to_config def __init__( self, @@ -57,7 +58,6 @@ def __init__( beta_end: float = 0.02, beta_schedule: str = "linear", trained_betas: Optional[np.ndarray] = None, - tensor_format: str = "pt", ): if trained_betas is not None: self.betas = torch.from_numpy(trained_betas) @@ -123,8 +123,6 @@ def set_timesteps(self, num_inference_steps: int): self.derivatives = [] - self.set_format(tensor_format=self.tensor_format) - def step( self, model_output: torch.FloatTensor, diff --git a/src/diffusers/schedulers/scheduling_pndm.py b/src/diffusers/schedulers/scheduling_pndm.py index 00e531082ab3..e86f13c824bf 100644 --- a/src/diffusers/schedulers/scheduling_pndm.py +++ b/src/diffusers/schedulers/scheduling_pndm.py @@ -22,7 +22,7 @@ import torch from ..configuration_utils import ConfigMixin, register_to_config -from .scheduling_utils import SchedulerOutput +from .scheduling_utils import SCHEDULER_CONFIG_NAME, SchedulerOutput def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999): @@ -89,6 +89,8 @@ class PNDMScheduler(ConfigMixin): """ + config_name = SCHEDULER_CONFIG_NAME + @register_to_config def __init__( self, diff --git a/src/diffusers/schedulers/scheduling_sde_ve.py b/src/diffusers/schedulers/scheduling_sde_ve.py index 0fd022a752ed..cb7863887b8f 100644 --- a/src/diffusers/schedulers/scheduling_sde_ve.py +++ b/src/diffusers/schedulers/scheduling_sde_ve.py @@ -23,7 +23,7 @@ from ..configuration_utils import ConfigMixin, register_to_config from ..utils import BaseOutput -from .scheduling_utils import SchedulerOutput +from .scheduling_utils import SCHEDULER_CONFIG_NAME, SchedulerOutput @dataclass @@ -67,6 +67,8 @@ class ScoreSdeVeScheduler(ConfigMixin): correct_steps (`int`): number of correction steps performed on a produced sample. """ + config_name = SCHEDULER_CONFIG_NAME + @register_to_config def __init__( self, @@ -190,7 +192,7 @@ def step_pred( drift = drift - diffusion[:, None, None, None] ** 2 * model_output # equation 6: sample noise for the diffusion term of - noise = torch.randn_like(sample, layout=sample.layout, generator=generator).to(sample.device) + noise = torch.randn(sample.shape, layout=sample.layout, generator=generator).to(sample.device) prev_sample_mean = sample - drift # subtract because `dt` is a small negative timestep # TODO is the variable diffusion the correct scaling term for the noise? prev_sample = prev_sample_mean + diffusion[:, None, None, None] * noise # add impact of diffusion field g @@ -234,7 +236,7 @@ def step_correct( # For small batch sizes, the paper "suggest replacing norm(z) with sqrt(d), where d is the dim. 
of z" # sample noise for correction - noise = torch.randn_like(sample, layout=sample.layout, generator=generator).to(sample.device) + noise = torch.randn(sample.shape, layout=sample.layout, generator=generator).to(sample.device) # compute step size from the model_output, the noise, and the snr grad_norm = torch.norm(model_output.reshape(model_output.shape[0], -1), dim=-1).mean() diff --git a/src/diffusers/schedulers/scheduling_sde_vp.py b/src/diffusers/schedulers/scheduling_sde_vp.py index f79c753d6417..578f131cd2e8 100644 --- a/src/diffusers/schedulers/scheduling_sde_vp.py +++ b/src/diffusers/schedulers/scheduling_sde_vp.py @@ -19,6 +19,7 @@ import numpy as np import torch +from .scheduling_utils import SCHEDULER_CONFIG_NAME from ..configuration_utils import ConfigMixin, register_to_config @@ -37,8 +38,10 @@ class ScoreSdeVpScheduler(ConfigMixin): """ + config_name = SCHEDULER_CONFIG_NAME + @register_to_config - def __init__(self, num_train_timesteps=2000, beta_min=0.1, beta_max=20, sampling_eps=1e-3, tensor_format="np"): + def __init__(self, num_train_timesteps=2000, beta_min=0.1, beta_max=20, sampling_eps=1e-3): self.sigmas = None self.discrete_sigmas = None self.timesteps = None From 5eb3763bdf0771c8a30dd5598207a266ab070015 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Sun, 18 Sep 2022 21:16:11 +0200 Subject: [PATCH 13/33] fix style --- src/diffusers/schedulers/scheduling_karras_ve.py | 2 +- src/diffusers/schedulers/scheduling_sde_vp.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_karras_ve.py b/src/diffusers/schedulers/scheduling_karras_ve.py index 31a074b330cd..9281842663e1 100644 --- a/src/diffusers/schedulers/scheduling_karras_ve.py +++ b/src/diffusers/schedulers/scheduling_karras_ve.py @@ -19,9 +19,9 @@ import numpy as np import torch -from .scheduling_utils import SCHEDULER_CONFIG_NAME from ..configuration_utils import ConfigMixin, register_to_config from ..utils import BaseOutput +from .scheduling_utils import SCHEDULER_CONFIG_NAME @dataclass diff --git a/src/diffusers/schedulers/scheduling_sde_vp.py b/src/diffusers/schedulers/scheduling_sde_vp.py index 578f131cd2e8..2b3b46446cef 100644 --- a/src/diffusers/schedulers/scheduling_sde_vp.py +++ b/src/diffusers/schedulers/scheduling_sde_vp.py @@ -19,8 +19,8 @@ import numpy as np import torch -from .scheduling_utils import SCHEDULER_CONFIG_NAME from ..configuration_utils import ConfigMixin, register_to_config +from .scheduling_utils import SCHEDULER_CONFIG_NAME class ScoreSdeVpScheduler(ConfigMixin): From a56ed1f0018d8676181459cd26bc1b4d9b9486a6 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Sun, 18 Sep 2022 21:40:58 +0200 Subject: [PATCH 14/33] sigmas has to be torch tensor --- src/diffusers/schedulers/scheduling_lms_discrete.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/diffusers/schedulers/scheduling_lms_discrete.py b/src/diffusers/schedulers/scheduling_lms_discrete.py index f8d14da5085d..589cb70d28e0 100644 --- a/src/diffusers/schedulers/scheduling_lms_discrete.py +++ b/src/diffusers/schedulers/scheduling_lms_discrete.py @@ -119,7 +119,8 @@ def set_timesteps(self, num_inference_steps: int): frac = np.mod(self.timesteps, 1.0) sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) sigmas = (1 - frac) * sigmas[low_idx] + frac * sigmas[high_idx] - self.sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32) + sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32) + self.sigmas = torch.from_numpy(sigmas) 
self.derivatives = [] From 7a86f0ccb289f84b8af4ecc28566d3e140505874 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Sun, 18 Sep 2022 21:41:17 +0200 Subject: [PATCH 15/33] removed set_format in readme --- src/diffusers/schedulers/README.md | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/diffusers/schedulers/README.md b/src/diffusers/schedulers/README.md index edf2299446fe..72ec3f023906 100644 --- a/src/diffusers/schedulers/README.md +++ b/src/diffusers/schedulers/README.md @@ -8,8 +8,7 @@ - Schedulers should provide one or more `def step(...)` functions that should be called iteratively to unroll the diffusion loop during the forward pass. -- Schedulers should be framework-agnostic, but provide a simple functionality to convert the scheduler into a specific framework, such as PyTorch -with a `set_format(...)` method. +- Schedulers should be framework specific. ## Examples From ee0185cbbe4cac8be4eb2b0993394d639d33c59b Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Sun, 18 Sep 2022 21:41:33 +0200 Subject: [PATCH 16/33] remove set format from docs --- docs/source/api/schedulers.mdx | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/docs/source/api/schedulers.mdx b/docs/source/api/schedulers.mdx index 2b6e58fe128d..b5af14d4bf4a 100644 --- a/docs/source/api/schedulers.mdx +++ b/docs/source/api/schedulers.mdx @@ -44,8 +44,7 @@ To this end, the design of schedulers is such that: The core API for any new scheduler must follow a limited structure. - Schedulers should provide one or more `def step(...)` functions that should be called to update the generated sample iteratively. - Schedulers should provide a `set_timesteps(...)` method that configures the parameters of a schedule function for a specific inference task. -- Schedulers should be framework-agnostic, but provide a simple functionality to convert the scheduler into a specific framework, such as PyTorch -with a `set_format(...)` method. +- Schedulers should be framework-specific. The base class [`SchedulerMixin`] implements low level utilities used by multiple schedulers. 
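With `tensor_format` and `set_format` removed from the schedulers, the following patch drops the corresponding `scheduler.set_format("pt")` calls from the pipelines: a scheduler is registered as constructed and operates on torch tensors throughout. A hedged sketch of the resulting denoising loop is below; the shapes are placeholders and the random tensor stands in for a real UNet prediction, so this illustrates the intended API rather than reproducing code from the patches.

import torch
from diffusers import DDIMScheduler

# No set_format("pt") call is needed; the scheduler is PyTorch-only after these patches.
scheduler = DDIMScheduler(num_train_timesteps=1000, beta_schedule="linear")
scheduler.set_timesteps(50)

sample = torch.randn(1, 3, 32, 32)  # placeholder shape for an image sample
for t in scheduler.timesteps:
    model_output = torch.randn_like(sample)  # stand-in for unet(sample, t).sample
    sample = scheduler.step(model_output, t, sample).prev_sample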
From 7e214e9d6ef2b292be97e15db5b69fbef9780dc3 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Sun, 18 Sep 2022 21:43:07 +0200 Subject: [PATCH 17/33] remove set_format from pipelines --- src/diffusers/pipelines/ddpm/pipeline_ddpm.py | 1 - .../pipelines/latent_diffusion/pipeline_latent_diffusion.py | 1 - .../pipeline_latent_diffusion_uncond.py | 1 - src/diffusers/pipelines/pndm/pipeline_pndm.py | 1 - .../pipelines/stable_diffusion/pipeline_stable_diffusion.py | 1 - .../stable_diffusion/pipeline_stable_diffusion_img2img.py | 2 -- .../stable_diffusion/pipeline_stable_diffusion_inpaint.py | 5 +++-- .../stable_diffusion/pipeline_stable_diffusion_onnx.py | 1 - .../stochastic_karras_ve/pipeline_stochastic_karras_ve.py | 1 - 9 files changed, 3 insertions(+), 11 deletions(-) diff --git a/src/diffusers/pipelines/ddpm/pipeline_ddpm.py b/src/diffusers/pipelines/ddpm/pipeline_ddpm.py index b7f7093e379b..a08ff1def1af 100644 --- a/src/diffusers/pipelines/ddpm/pipeline_ddpm.py +++ b/src/diffusers/pipelines/ddpm/pipeline_ddpm.py @@ -36,7 +36,6 @@ class DDPMPipeline(DiffusionPipeline): def __init__(self, unet, scheduler): super().__init__() - scheduler = scheduler.set_format("pt") self.register_modules(unet=unet, scheduler=scheduler) @torch.no_grad() diff --git a/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py b/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py index 8caa11dbdf76..feab6d1d2038 100644 --- a/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +++ b/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py @@ -46,7 +46,6 @@ def __init__( scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], ): super().__init__() - scheduler = scheduler.set_format("pt") self.register_modules(vqvae=vqvae, bert=bert, tokenizer=tokenizer, unet=unet, scheduler=scheduler) @torch.no_grad() diff --git a/src/diffusers/pipelines/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py b/src/diffusers/pipelines/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py index 5574b65df9f8..aa80c97401d9 100644 --- a/src/diffusers/pipelines/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py +++ b/src/diffusers/pipelines/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py @@ -24,7 +24,6 @@ class LDMPipeline(DiffusionPipeline): def __init__(self, vqvae: VQModel, unet: UNet2DModel, scheduler: DDIMScheduler): super().__init__() - scheduler = scheduler.set_format("pt") self.register_modules(vqvae=vqvae, unet=unet, scheduler=scheduler) @torch.no_grad() diff --git a/src/diffusers/pipelines/pndm/pipeline_pndm.py b/src/diffusers/pipelines/pndm/pipeline_pndm.py index ae6c10e9e967..d5a0d8c3d362 100644 --- a/src/diffusers/pipelines/pndm/pipeline_pndm.py +++ b/src/diffusers/pipelines/pndm/pipeline_pndm.py @@ -40,7 +40,6 @@ class PNDMPipeline(DiffusionPipeline): def __init__(self, unet: UNet2DModel, scheduler: PNDMScheduler): super().__init__() - scheduler = scheduler.set_format("pt") self.register_modules(unet=unet, scheduler=scheduler) @torch.no_grad() diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index 9f1211b43013..0b7e0adc93e0 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -53,7 +53,6 @@ def __init__( feature_extractor: CLIPFeatureExtractor, ): super().__init__() - scheduler = scheduler.set_format("pt") 
if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: warnings.warn( diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py index e7adb4d1a33b..aec8d3b96319 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py @@ -65,7 +65,6 @@ def __init__( feature_extractor: CLIPFeatureExtractor, ): super().__init__() - scheduler = scheduler.set_format("pt") if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: warnings.warn( @@ -266,7 +265,6 @@ def __call__( # the model input needs to be scaled to match the continuous ODE formulation in K-LMS latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5) latent_model_input = latent_model_input.to(self.unet.dtype) - t = t.to(self.unet.dtype) # predict the noise residual noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py index b9ad36f1a2bf..4ffba6570713 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py @@ -83,7 +83,6 @@ def __init__( feature_extractor: CLIPFeatureExtractor, ): super().__init__() - scheduler = scheduler.set_format("pt") logger.info("`StableDiffusionInpaintPipeline` is experimental and will very likely change in the future.") if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: @@ -302,7 +301,9 @@ def __call__( latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample # masking - init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, t) + init_latents_proper = self.scheduler.add_noise( + init_latents_orig, noise, torch.tensor([t] * batch_size, dtype=torch.long, device=self.device) + ) latents = (init_latents_proper * mask) + (latents * (1 - mask)) # scale and decode the image latents with vae diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_onnx.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_onnx.py index ccba29ade5d3..a5dfd988b2bc 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_onnx.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_onnx.py @@ -31,7 +31,6 @@ def __init__( feature_extractor: CLIPFeatureExtractor, ): super().__init__() - scheduler = scheduler.set_format("np") self.register_modules( vae_decoder=vae_decoder, text_encoder=text_encoder, diff --git a/src/diffusers/pipelines/stochastic_karras_ve/pipeline_stochastic_karras_ve.py b/src/diffusers/pipelines/stochastic_karras_ve/pipeline_stochastic_karras_ve.py index 1984a25ac0c6..b49c540121e7 100644 --- a/src/diffusers/pipelines/stochastic_karras_ve/pipeline_stochastic_karras_ve.py +++ b/src/diffusers/pipelines/stochastic_karras_ve/pipeline_stochastic_karras_ve.py @@ -30,7 +30,6 @@ class KarrasVePipeline(DiffusionPipeline): def __init__(self, unet: UNet2DModel, scheduler: KarrasVeScheduler): super().__init__() - scheduler = scheduler.set_format("pt") self.register_modules(unet=unet, scheduler=scheduler) @torch.no_grad() From 5a608dc54e443b0097101cd139bc796bb6f9e917 Mon 
Sep 17 00:00:00 2001 From: Kashif Rasul Date: Sun, 18 Sep 2022 21:43:31 +0200 Subject: [PATCH 18/33] update tests --- .../textual_inversion/textual_inversion.py | 5 +- .../train_unconditional.py | 2 +- tests/test_pipelines.py | 33 ++- tests/test_scheduler.py | 189 +++++++++--------- tests/test_training.py | 2 - 5 files changed, 113 insertions(+), 118 deletions(-) diff --git a/examples/textual_inversion/textual_inversion.py b/examples/textual_inversion/textual_inversion.py index de5761646a00..53b4cf2f1d5c 100644 --- a/examples/textual_inversion/textual_inversion.py +++ b/examples/textual_inversion/textual_inversion.py @@ -424,7 +424,10 @@ def main(): # TODO (patil-suraj): load scheduler using args noise_scheduler = DDPMScheduler( - beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, tensor_format="pt" + beta_start=0.00085, + beta_end=0.012, + beta_schedule="scaled_linear", + num_train_timesteps=1000, ) train_dataset = TextualInversionDataset( diff --git a/examples/unconditional_image_generation/train_unconditional.py b/examples/unconditional_image_generation/train_unconditional.py index f6affe8a1400..243e433a5b83 100644 --- a/examples/unconditional_image_generation/train_unconditional.py +++ b/examples/unconditional_image_generation/train_unconditional.py @@ -59,7 +59,7 @@ def main(args): "UpBlock2D", ), ) - noise_scheduler = DDPMScheduler(num_train_timesteps=1000, tensor_format="pt") + noise_scheduler = DDPMScheduler(num_train_timesteps=1000) optimizer = torch.optim.AdamW( model.parameters(), lr=args.learning_rate, diff --git a/tests/test_pipelines.py b/tests/test_pipelines.py index 102a55a93e4b..c32e53b936e8 100644 --- a/tests/test_pipelines.py +++ b/tests/test_pipelines.py @@ -192,7 +192,7 @@ def to(self, device): def test_ddim(self): unet = self.dummy_uncond_unet - scheduler = DDIMScheduler(tensor_format="pt") + scheduler = DDIMScheduler() ddpm = DDIMPipeline(unet=unet, scheduler=scheduler) ddpm.to(torch_device) @@ -221,7 +221,7 @@ def test_ddim(self): def test_pndm_cifar10(self): unet = self.dummy_uncond_unet - scheduler = PNDMScheduler(tensor_format="pt") + scheduler = PNDMScheduler() pndm = PNDMPipeline(unet=unet, scheduler=scheduler) pndm.to(torch_device) @@ -243,7 +243,7 @@ def test_pndm_cifar10(self): def test_ldm_text2img(self): unet = self.dummy_cond_unet - scheduler = DDIMScheduler(tensor_format="pt") + scheduler = DDIMScheduler() vae = self.dummy_vae bert = self.dummy_text_encoder tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") @@ -340,7 +340,7 @@ def test_stable_diffusion_ddim(self): def test_stable_diffusion_pndm(self): device = "cpu" # ensure determinism for the device-dependent torch.Generator unet = self.dummy_cond_unet - scheduler = PNDMScheduler(tensor_format="pt", skip_prk_steps=True) + scheduler = PNDMScheduler(skip_prk_steps=True) vae = self.dummy_vae bert = self.dummy_text_encoder tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") @@ -461,7 +461,7 @@ def test_stable_diffusion_attention_chunk(self): def test_score_sde_ve_pipeline(self): unet = self.dummy_uncond_unet - scheduler = ScoreSdeVeScheduler(tensor_format="pt") + scheduler = ScoreSdeVeScheduler() sde_ve = ScoreSdeVePipeline(unet=unet, scheduler=scheduler) sde_ve.to(torch_device) @@ -485,7 +485,7 @@ def test_score_sde_ve_pipeline(self): def test_ldm_uncond(self): unet = self.dummy_uncond_unet - scheduler = DDIMScheduler(tensor_format="pt") + scheduler = DDIMScheduler() vae = self.dummy_vq_model ldm = 
LDMPipeline(unet=unet, vqvae=vae, scheduler=scheduler) @@ -513,7 +513,7 @@ def test_ldm_uncond(self): def test_karras_ve_pipeline(self): unet = self.dummy_uncond_unet - scheduler = KarrasVeScheduler(tensor_format="pt") + scheduler = KarrasVeScheduler() pipe = KarrasVePipeline(unet=unet, scheduler=scheduler) pipe.to(torch_device) @@ -536,7 +536,7 @@ def test_karras_ve_pipeline(self): def test_stable_diffusion_img2img(self): device = "cpu" # ensure determinism for the device-dependent torch.Generator unet = self.dummy_cond_unet - scheduler = PNDMScheduler(tensor_format="pt", skip_prk_steps=True) + scheduler = PNDMScheduler(skip_prk_steps=True) vae = self.dummy_vae bert = self.dummy_text_encoder tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") @@ -647,7 +647,7 @@ def test_stable_diffusion_img2img_k_lms(self): def test_stable_diffusion_inpaint(self): device = "cpu" # ensure determinism for the device-dependent torch.Generator unet = self.dummy_cond_unet - scheduler = PNDMScheduler(tensor_format="pt", skip_prk_steps=True) + scheduler = PNDMScheduler(skip_prk_steps=True) vae = self.dummy_vae bert = self.dummy_text_encoder tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") @@ -843,7 +843,6 @@ def test_ddpm_cifar10(self): unet = UNet2DModel.from_pretrained(model_id) scheduler = DDPMScheduler.from_config(model_id) - scheduler = scheduler.set_format("pt") ddpm = DDPMPipeline(unet=unet, scheduler=scheduler) ddpm.to(torch_device) @@ -883,7 +882,7 @@ def test_ddim_cifar10(self): model_id = "google/ddpm-cifar10-32" unet = UNet2DModel.from_pretrained(model_id) - scheduler = DDIMScheduler(tensor_format="pt") + scheduler = DDIMScheduler() ddim = DDIMPipeline(unet=unet, scheduler=scheduler) ddim.to(torch_device) @@ -903,7 +902,7 @@ def test_pndm_cifar10(self): model_id = "google/ddpm-cifar10-32" unet = UNet2DModel.from_pretrained(model_id) - scheduler = PNDMScheduler(tensor_format="pt") + scheduler = PNDMScheduler() pndm = PNDMPipeline(unet=unet, scheduler=scheduler) pndm.to(torch_device) @@ -1044,8 +1043,8 @@ def test_ddpm_ddim_equality(self): model_id = "google/ddpm-cifar10-32" unet = UNet2DModel.from_pretrained(model_id) - ddpm_scheduler = DDPMScheduler(tensor_format="pt") - ddim_scheduler = DDIMScheduler(tensor_format="pt") + ddpm_scheduler = DDPMScheduler() + ddim_scheduler = DDIMScheduler() ddpm = DDPMPipeline(unet=unet, scheduler=ddpm_scheduler) ddpm.to(torch_device) @@ -1068,8 +1067,8 @@ def test_ddpm_ddim_equality_batched(self): model_id = "google/ddpm-cifar10-32" unet = UNet2DModel.from_pretrained(model_id) - ddpm_scheduler = DDPMScheduler(tensor_format="pt") - ddim_scheduler = DDIMScheduler(tensor_format="pt") + ddpm_scheduler = DDPMScheduler() + ddim_scheduler = DDIMScheduler() ddpm = DDPMPipeline(unet=unet, scheduler=ddpm_scheduler) ddpm.to(torch_device) @@ -1094,7 +1093,7 @@ def test_ddpm_ddim_equality_batched(self): def test_karras_ve_pipeline(self): model_id = "google/ncsnpp-celebahq-256" model = UNet2DModel.from_pretrained(model_id) - scheduler = KarrasVeScheduler(tensor_format="pt") + scheduler = KarrasVeScheduler() pipe = KarrasVePipeline(unet=model, scheduler=scheduler) pipe.to(torch_device) diff --git a/tests/test_scheduler.py b/tests/test_scheduler.py index 7377797bebfa..007a008b5f62 100755 --- a/tests/test_scheduler.py +++ b/tests/test_scheduler.py @@ -173,33 +173,33 @@ def test_step_shape(self): self.assertEqual(output_0.shape, sample.shape) self.assertEqual(output_0.shape, output_1.shape) - def 
test_pytorch_equal_numpy(self): - kwargs = dict(self.forward_default_kwargs) + # def test_pytorch_equal_numpy(self): + # kwargs = dict(self.forward_default_kwargs) - num_inference_steps = kwargs.pop("num_inference_steps", None) + # num_inference_steps = kwargs.pop("num_inference_steps", None) - for scheduler_class in self.scheduler_classes: - sample_pt = self.dummy_sample - residual_pt = 0.1 * sample_pt + # for scheduler_class in self.scheduler_classes: + # sample_pt = self.dummy_sample + # residual_pt = 0.1 * sample_pt - sample = sample_pt.numpy() - residual = 0.1 * sample + # sample = sample_pt.numpy() + # residual = 0.1 * sample - scheduler_config = self.get_scheduler_config() - scheduler = scheduler_class(tensor_format="np", **scheduler_config) + # scheduler_config = self.get_scheduler_config() + # scheduler = scheduler_class(tensor_format="np", **scheduler_config) - scheduler_pt = scheduler_class(tensor_format="pt", **scheduler_config) + # scheduler_pt = scheduler_class(tensor_format="pt", **scheduler_config) - if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): - scheduler.set_timesteps(num_inference_steps) - scheduler_pt.set_timesteps(num_inference_steps) - elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): - kwargs["num_inference_steps"] = num_inference_steps + # if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): + # scheduler.set_timesteps(num_inference_steps) + # scheduler_pt.set_timesteps(num_inference_steps) + # elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): + # kwargs["num_inference_steps"] = num_inference_steps - output = scheduler.step(residual, 1, sample, **kwargs).prev_sample - output_pt = scheduler_pt.step(residual_pt, 1, sample_pt, **kwargs).prev_sample + # output = scheduler.step(residual, 1, sample, **kwargs).prev_sample + # output_pt = scheduler_pt.step(residual_pt, 1, sample_pt, **kwargs).prev_sample - assert np.sum(np.abs(output - output_pt.numpy())) < 1e-4, "Scheduler outputs are not identical" + # assert np.sum(np.abs(output - output_pt.numpy())) < 1e-4, "Scheduler outputs are not identical" def test_scheduler_outputs_equivalence(self): def set_nan_tensor_to_zero(t): @@ -266,7 +266,6 @@ def get_scheduler_config(self, **kwargs): "beta_schedule": "linear", "variance_type": "fixed_small", "clip_sample": True, - "tensor_format": "pt", } config.update(**kwargs) @@ -387,7 +386,7 @@ def test_steps_offset(self): scheduler_config = self.get_scheduler_config(steps_offset=1) scheduler = scheduler_class(**scheduler_config) scheduler.set_timesteps(5) - assert torch.equal(scheduler.timesteps, torch.tensor([801, 601, 401, 201, 1])) + assert np.equal(scheduler.timesteps, np.array([801, 601, 401, 201, 1])).all() def test_betas(self): for beta_start, beta_end in zip([0.0001, 0.001, 0.01, 0.1], [0.002, 0.02, 0.2, 2]): @@ -556,71 +555,71 @@ def full_loop(self, **config): return sample - def test_pytorch_equal_numpy(self): - kwargs = dict(self.forward_default_kwargs) - num_inference_steps = kwargs.pop("num_inference_steps", None) + # def test_pytorch_equal_numpy(self): + # kwargs = dict(self.forward_default_kwargs) + # num_inference_steps = kwargs.pop("num_inference_steps", None) - for scheduler_class in self.scheduler_classes: - sample_pt = self.dummy_sample - residual_pt = 0.1 * sample_pt - dummy_past_residuals_pt = [residual_pt + 0.2, residual_pt + 0.15, residual_pt + 0.1, residual_pt + 0.05] + # for scheduler_class in self.scheduler_classes: + # sample_pt = 
self.dummy_sample + # residual_pt = 0.1 * sample_pt + # dummy_past_residuals_pt = [residual_pt + 0.2, residual_pt + 0.15, residual_pt + 0.1, residual_pt + 0.05] - sample = sample_pt.numpy() - residual = 0.1 * sample - dummy_past_residuals = [residual + 0.2, residual + 0.15, residual + 0.1, residual + 0.05] + # sample = sample_pt.numpy() + # residual = 0.1 * sample + # dummy_past_residuals = [residual + 0.2, residual + 0.15, residual + 0.1, residual + 0.05] - scheduler_config = self.get_scheduler_config() - scheduler = scheduler_class(tensor_format="np", **scheduler_config) + # scheduler_config = self.get_scheduler_config() + # scheduler = scheduler_class(tensor_format="np", **scheduler_config) - scheduler_pt = scheduler_class(tensor_format="pt", **scheduler_config) + # scheduler_pt = scheduler_class(tensor_format="pt", **scheduler_config) - if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): - scheduler.set_timesteps(num_inference_steps) - scheduler_pt.set_timesteps(num_inference_steps) - elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): - kwargs["num_inference_steps"] = num_inference_steps + # if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): + # scheduler.set_timesteps(num_inference_steps) + # scheduler_pt.set_timesteps(num_inference_steps) + # elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): + # kwargs["num_inference_steps"] = num_inference_steps - # copy over dummy past residuals (must be done after set_timesteps) - scheduler.ets = dummy_past_residuals[:] - scheduler_pt.ets = dummy_past_residuals_pt[:] + # # copy over dummy past residuals (must be done after set_timesteps) + # scheduler.ets = dummy_past_residuals[:] + # scheduler_pt.ets = dummy_past_residuals_pt[:] - output = scheduler.step_prk(residual, 1, sample, **kwargs).prev_sample - output_pt = scheduler_pt.step_prk(residual_pt, 1, sample_pt, **kwargs).prev_sample - assert np.sum(np.abs(output - output_pt.numpy())) < 1e-4, "Scheduler outputs are not identical" + # output = scheduler.step_prk(residual, 1, sample, **kwargs).prev_sample + # output_pt = scheduler_pt.step_prk(residual_pt, 1, sample_pt, **kwargs).prev_sample + # assert np.sum(np.abs(output - output_pt.numpy())) < 1e-4, "Scheduler outputs are not identical" - output = scheduler.step_plms(residual, 1, sample, **kwargs).prev_sample - output_pt = scheduler_pt.step_plms(residual_pt, 1, sample_pt, **kwargs).prev_sample + # output = scheduler.step_plms(residual, 1, sample, **kwargs).prev_sample + # output_pt = scheduler_pt.step_plms(residual_pt, 1, sample_pt, **kwargs).prev_sample - assert np.sum(np.abs(output - output_pt.numpy())) < 1e-4, "Scheduler outputs are not identical" + # assert np.sum(np.abs(output - output_pt.numpy())) < 1e-4, "Scheduler outputs are not identical" - def test_set_format(self): - kwargs = dict(self.forward_default_kwargs) - num_inference_steps = kwargs.pop("num_inference_steps", None) + # def test_set_format(self): + # kwargs = dict(self.forward_default_kwargs) + # num_inference_steps = kwargs.pop("num_inference_steps", None) - for scheduler_class in self.scheduler_classes: - scheduler_config = self.get_scheduler_config() - scheduler = scheduler_class(tensor_format="np", **scheduler_config) - scheduler_pt = scheduler_class(tensor_format="pt", **scheduler_config) + # for scheduler_class in self.scheduler_classes: + # scheduler_config = self.get_scheduler_config() + # scheduler = scheduler_class(tensor_format="np", **scheduler_config) + # 
scheduler_pt = scheduler_class(tensor_format="pt", **scheduler_config) - if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): - scheduler.set_timesteps(num_inference_steps) - scheduler_pt.set_timesteps(num_inference_steps) - - for key, value in vars(scheduler).items(): - # we only allow `ets` attr to be a list - assert not isinstance(value, list) or key in [ - "ets" - ], f"Scheduler is not correctly set to np format, the attribute {key} is {type(value)}" - - # check if `scheduler.set_format` does convert correctly attrs to pt format - for key, value in vars(scheduler_pt).items(): - # we only allow `ets` attr to be a list - assert not isinstance(value, list) or key in [ - "ets" - ], f"Scheduler is not correctly set to pt format, the attribute {key} is {type(value)}" - assert not isinstance( - value, np.ndarray - ), f"Scheduler is not correctly set to pt format, the attribute {key} is {type(value)}" + # if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): + # scheduler.set_timesteps(num_inference_steps) + # scheduler_pt.set_timesteps(num_inference_steps) + + # for key, value in vars(scheduler).items(): + # # we only allow `ets` attr to be a list + # assert not isinstance(value, list) or key in [ + # "ets" + # ], f"Scheduler is not correctly set to np format, the attribute {key} is {type(value)}" + + # # check if `scheduler.set_format` does convert correctly attrs to pt format + # for key, value in vars(scheduler_pt).items(): + # # we only allow `ets` attr to be a list + # assert not isinstance(value, list) or key in [ + # "ets" + # ], f"Scheduler is not correctly set to pt format, the attribute {key} is {type(value)}" + # assert not isinstance( + # value, np.ndarray + # ), f"Scheduler is not correctly set to pt format, the attribute {key} is {type(value)}" def test_step_shape(self): kwargs = dict(self.forward_default_kwargs) @@ -667,12 +666,10 @@ def test_steps_offset(self): scheduler_config = self.get_scheduler_config(steps_offset=1) scheduler = scheduler_class(**scheduler_config) scheduler.set_timesteps(10) - assert torch.equal( + assert np.equal( scheduler.timesteps, - torch.tensor( - [901, 851, 851, 801, 801, 751, 751, 701, 701, 651, 651, 601, 601, 501, 401, 301, 201, 101, 1] - ), - ) + np.array([901, 851, 851, 801, 801, 751, 751, 701, 701, 651, 651, 601, 601, 501, 401, 301, 201, 101, 1]), + ).all() def test_betas(self): for beta_start, beta_end in zip([0.0001, 0.001], [0.002, 0.02]): @@ -786,7 +783,6 @@ def get_scheduler_config(self, **kwargs): "sigma_min": 0.01, "sigma_max": 1348, "sampling_eps": 1e-5, - "tensor_format": "pt", # TODO add test for tensor formats } config.update(**kwargs) @@ -936,7 +932,6 @@ def get_scheduler_config(self, **kwargs): "beta_end": 0.02, "beta_schedule": "linear", "trained_betas": None, - "tensor_format": "pt", } config.update(**kwargs) @@ -958,27 +953,27 @@ def test_time_indices(self): for t in [0, 500, 800]: self.check_over_forward(time_step=t) - def test_pytorch_equal_numpy(self): - for scheduler_class in self.scheduler_classes: - sample_pt = self.dummy_sample - residual_pt = 0.1 * sample_pt + # def test_pytorch_equal_numpy(self): + # for scheduler_class in self.scheduler_classes: + # sample_pt = self.dummy_sample + # residual_pt = 0.1 * sample_pt - sample = sample_pt.numpy() - residual = 0.1 * sample + # sample = sample_pt.numpy() + # residual = 0.1 * sample - scheduler_config = self.get_scheduler_config() - scheduler_config["tensor_format"] = "np" - scheduler = scheduler_class(**scheduler_config) + # 
scheduler_config = self.get_scheduler_config() + # scheduler_config["tensor_format"] = "np" + # scheduler = scheduler_class(**scheduler_config) - scheduler_config["tensor_format"] = "pt" - scheduler_pt = scheduler_class(**scheduler_config) + # scheduler_config["tensor_format"] = "pt" + # scheduler_pt = scheduler_class(**scheduler_config) - scheduler.set_timesteps(self.num_inference_steps) - scheduler_pt.set_timesteps(self.num_inference_steps) + # scheduler.set_timesteps(self.num_inference_steps) + # scheduler_pt.set_timesteps(self.num_inference_steps) - output = scheduler.step(residual, 1, sample).prev_sample - output_pt = scheduler_pt.step(residual_pt, 1, sample_pt).prev_sample - assert np.sum(np.abs(output - output_pt.numpy())) < 1e-4, "Scheduler outputs are not identical" + # output = scheduler.step(residual, 1, sample).prev_sample + # output_pt = scheduler_pt.step(residual_pt, 1, sample_pt).prev_sample + # assert np.sum(np.abs(output - output_pt.numpy())) < 1e-4, "Scheduler outputs are not identical" def test_full_loop_no_noise(self): scheduler_class = self.scheduler_classes[0] diff --git a/tests/test_training.py b/tests/test_training.py index 519c5ab9e716..41aae07e33c6 100644 --- a/tests/test_training.py +++ b/tests/test_training.py @@ -41,7 +41,6 @@ def test_training_step_equality(self): beta_end=0.02, beta_schedule="linear", clip_sample=True, - tensor_format="pt", ) ddim_scheduler = DDIMScheduler( num_train_timesteps=1000, @@ -49,7 +48,6 @@ def test_training_step_equality(self): beta_end=0.02, beta_schedule="linear", clip_sample=True, - tensor_format="pt", ) assert ddpm_scheduler.config.num_train_timesteps == ddim_scheduler.config.num_train_timesteps From 9014aee2a1b92d9b0229f901c4304658c332665a Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Sun, 18 Sep 2022 21:59:01 +0200 Subject: [PATCH 19/33] fix typo --- .../stable_diffusion/pipeline_stable_diffusion_img2img.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py index aec8d3b96319..d27df65000aa 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py @@ -69,7 +69,7 @@ def __init__( if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: warnings.warn( f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" - f" should be set to 1 istead of {scheduler.config.steps_offset}. Please make sure " + f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " "to update the config accordingly as leaving `steps_offset` might led to incorrect results" " in future versions. 
If you have downloaded this checkpoint from the Hugging Face Hub," " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" From 03dd2874fe1e3c9c56e7f21720b7086b4255fd0c Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Mon, 19 Sep 2022 11:31:01 +0200 Subject: [PATCH 20/33] continue to use mixin --- src/diffusers/schedulers/scheduling_ddim.py | 6 +- src/diffusers/schedulers/scheduling_ddpm.py | 6 +- .../schedulers/scheduling_karras_ve.py | 6 +- .../schedulers/scheduling_lms_discrete.py | 6 +- src/diffusers/schedulers/scheduling_pndm.py | 6 +- src/diffusers/schedulers/scheduling_sde_ve.py | 6 +- src/diffusers/schedulers/scheduling_sde_vp.py | 6 +- src/diffusers/schedulers/scheduling_utils.py | 80 ------------------- 8 files changed, 14 insertions(+), 108 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_ddim.py b/src/diffusers/schedulers/scheduling_ddim.py index a171fe7ff3c2..d33f8612ca6c 100644 --- a/src/diffusers/schedulers/scheduling_ddim.py +++ b/src/diffusers/schedulers/scheduling_ddim.py @@ -23,7 +23,7 @@ import torch from ..configuration_utils import ConfigMixin, register_to_config -from .scheduling_utils import SCHEDULER_CONFIG_NAME, SchedulerOutput +from .scheduling_utils import SchedulerOutput, SchedulerMixin def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> torch.Tensor: @@ -55,7 +55,7 @@ def alpha_bar(time_step): return torch.tensor(betas) -class DDIMScheduler(ConfigMixin): +class DDIMScheduler(SchedulerMixin, ConfigMixin): """ Denoising diffusion implicit models is a scheduler that extends the denoising procedure introduced in denoising diffusion probabilistic models (DDPMs) with non-Markovian guidance. @@ -89,8 +89,6 @@ class DDIMScheduler(ConfigMixin): """ - config_name = SCHEDULER_CONFIG_NAME - @register_to_config def __init__( self, diff --git a/src/diffusers/schedulers/scheduling_ddpm.py b/src/diffusers/schedulers/scheduling_ddpm.py index 5367dcc1a067..affdac9f1d96 100644 --- a/src/diffusers/schedulers/scheduling_ddpm.py +++ b/src/diffusers/schedulers/scheduling_ddpm.py @@ -21,7 +21,7 @@ import torch from ..configuration_utils import ConfigMixin, register_to_config -from .scheduling_utils import SCHEDULER_CONFIG_NAME, SchedulerOutput +from .scheduling_utils import SchedulerOutput, SchedulerMixin def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999): @@ -53,7 +53,7 @@ def alpha_bar(time_step): return torch.tensor(betas, dtype=torch.float32) -class DDPMScheduler(ConfigMixin): +class DDPMScheduler(SchedulerMixin, ConfigMixin): """ Denoising diffusion probabilistic models (DDPMs) explores the connections between denoising score matching and Langevin dynamics sampling. @@ -82,8 +82,6 @@ class DDPMScheduler(ConfigMixin): """ - config_name = SCHEDULER_CONFIG_NAME - @register_to_config def __init__( self, diff --git a/src/diffusers/schedulers/scheduling_karras_ve.py b/src/diffusers/schedulers/scheduling_karras_ve.py index 9281842663e1..5958b2a2f0d3 100644 --- a/src/diffusers/schedulers/scheduling_karras_ve.py +++ b/src/diffusers/schedulers/scheduling_karras_ve.py @@ -21,7 +21,7 @@ from ..configuration_utils import ConfigMixin, register_to_config from ..utils import BaseOutput -from .scheduling_utils import SCHEDULER_CONFIG_NAME +from .scheduling_utils import SchedulerMixin @dataclass @@ -41,7 +41,7 @@ class KarrasVeOutput(BaseOutput): derivative: torch.FloatTensor -class KarrasVeScheduler(ConfigMixin): +class KarrasVeScheduler(SchedulerMixin, ConfigMixin): """ Stochastic sampling from Karras et al. 
[1] tailored to the Variance-Expanding (VE) models [2]. Use Algorithm 2 and the VE column of Table 1 from [1] for reference. @@ -73,8 +73,6 @@ class KarrasVeScheduler(ConfigMixin): """ - config_name = SCHEDULER_CONFIG_NAME - @register_to_config def __init__( self, diff --git a/src/diffusers/schedulers/scheduling_lms_discrete.py b/src/diffusers/schedulers/scheduling_lms_discrete.py index 589cb70d28e0..9e5320353de2 100644 --- a/src/diffusers/schedulers/scheduling_lms_discrete.py +++ b/src/diffusers/schedulers/scheduling_lms_discrete.py @@ -20,10 +20,10 @@ from scipy import integrate from ..configuration_utils import ConfigMixin, register_to_config -from .scheduling_utils import SCHEDULER_CONFIG_NAME, SchedulerOutput +from .scheduling_utils import SchedulerOutput, SchedulerMixin -class LMSDiscreteScheduler(ConfigMixin): +class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin): """ Linear Multistep Scheduler for discrete beta schedules. Based on the original k-diffusion implementation by Katherine Crowson: @@ -48,8 +48,6 @@ class LMSDiscreteScheduler(ConfigMixin): """ - config_name = SCHEDULER_CONFIG_NAME - @register_to_config def __init__( self, diff --git a/src/diffusers/schedulers/scheduling_pndm.py b/src/diffusers/schedulers/scheduling_pndm.py index e86f13c824bf..bd2e1b91bfea 100644 --- a/src/diffusers/schedulers/scheduling_pndm.py +++ b/src/diffusers/schedulers/scheduling_pndm.py @@ -22,7 +22,7 @@ import torch from ..configuration_utils import ConfigMixin, register_to_config -from .scheduling_utils import SCHEDULER_CONFIG_NAME, SchedulerOutput +from .scheduling_utils import SchedulerOutput, SchedulerMixin def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999): @@ -54,7 +54,7 @@ def alpha_bar(time_step): return torch.tensor(betas, dtype=torch.float32) -class PNDMScheduler(ConfigMixin): +class PNDMScheduler(SchedulerMixin, ConfigMixin): """ Pseudo numerical methods for diffusion models (PNDM) proposes using more advanced ODE integration techniques, namely Runge-Kutta method and a linear multi-step method. @@ -89,8 +89,6 @@ class PNDMScheduler(ConfigMixin): """ - config_name = SCHEDULER_CONFIG_NAME - @register_to_config def __init__( self, diff --git a/src/diffusers/schedulers/scheduling_sde_ve.py b/src/diffusers/schedulers/scheduling_sde_ve.py index cb7863887b8f..6c40a09fa196 100644 --- a/src/diffusers/schedulers/scheduling_sde_ve.py +++ b/src/diffusers/schedulers/scheduling_sde_ve.py @@ -23,7 +23,7 @@ from ..configuration_utils import ConfigMixin, register_to_config from ..utils import BaseOutput -from .scheduling_utils import SCHEDULER_CONFIG_NAME, SchedulerOutput +from .scheduling_utils import SchedulerOutput, SchedulerMixin @dataclass @@ -43,7 +43,7 @@ class SdeVeOutput(BaseOutput): prev_sample_mean: torch.FloatTensor -class ScoreSdeVeScheduler(ConfigMixin): +class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin): """ The variance exploding stochastic differential equation (SDE) scheduler. @@ -67,8 +67,6 @@ class ScoreSdeVeScheduler(ConfigMixin): correct_steps (`int`): number of correction steps performed on a produced sample. 
""" - config_name = SCHEDULER_CONFIG_NAME - @register_to_config def __init__( self, diff --git a/src/diffusers/schedulers/scheduling_sde_vp.py b/src/diffusers/schedulers/scheduling_sde_vp.py index 2b3b46446cef..78acc3b36a64 100644 --- a/src/diffusers/schedulers/scheduling_sde_vp.py +++ b/src/diffusers/schedulers/scheduling_sde_vp.py @@ -20,10 +20,10 @@ import torch from ..configuration_utils import ConfigMixin, register_to_config -from .scheduling_utils import SCHEDULER_CONFIG_NAME +from .scheduling_utils import SchedulerMixin -class ScoreSdeVpScheduler(ConfigMixin): +class ScoreSdeVpScheduler(SchedulerMixin, ConfigMixin): """ The variance preserving stochastic differential equation (SDE) scheduler. @@ -38,8 +38,6 @@ class ScoreSdeVpScheduler(ConfigMixin): """ - config_name = SCHEDULER_CONFIG_NAME - @register_to_config def __init__(self, num_train_timesteps=2000, beta_min=0.1, beta_max=20, sampling_eps=1e-3): self.sigmas = None diff --git a/src/diffusers/schedulers/scheduling_utils.py b/src/diffusers/schedulers/scheduling_utils.py index f2bcd73acf32..a47aaea0781c 100644 --- a/src/diffusers/schedulers/scheduling_utils.py +++ b/src/diffusers/schedulers/scheduling_utils.py @@ -43,83 +43,3 @@ class SchedulerMixin: """ config_name = SCHEDULER_CONFIG_NAME - ignore_for_config = ["tensor_format"] - - def set_format(self, tensor_format="pt"): - self.tensor_format = tensor_format - if tensor_format == "pt": - for key, value in vars(self).items(): - if isinstance(value, np.ndarray): - setattr(self, key, torch.from_numpy(value)) - - return self - - def clip(self, tensor, min_value=None, max_value=None): - tensor_format = getattr(self, "tensor_format", "pt") - - if tensor_format == "np": - return np.clip(tensor, min_value, max_value) - elif tensor_format == "pt": - return torch.clamp(tensor, min_value, max_value) - - raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.") - - def log(self, tensor): - tensor_format = getattr(self, "tensor_format", "pt") - - if tensor_format == "np": - return np.log(tensor) - elif tensor_format == "pt": - return torch.log(tensor) - - raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.") - - def match_shape(self, values: Union[np.ndarray, torch.Tensor], broadcast_array: Union[np.ndarray, torch.Tensor]): - """ - Turns a 1-D array into an array or tensor with len(broadcast_array.shape) dims. - - Args: - values: an array or tensor of values to extract. - broadcast_array: an array with a larger shape of K dimensions with the batch - dimension equal to the length of timesteps. - Returns: - a tensor of shape [batch_size, 1, ...] where the shape has K dims. 
- """ - - tensor_format = getattr(self, "tensor_format", "pt") - values = values.flatten() - - while len(values.shape) < len(broadcast_array.shape): - values = values[..., None] - if tensor_format == "pt": - values = values.to(broadcast_array.device) - - return values - - def norm(self, tensor): - tensor_format = getattr(self, "tensor_format", "pt") - if tensor_format == "np": - return np.linalg.norm(tensor) - elif tensor_format == "pt": - return torch.norm(tensor.reshape(tensor.shape[0], -1), dim=-1).mean() - - raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.") - - def randn_like(self, tensor, generator=None): - tensor_format = getattr(self, "tensor_format", "pt") - if tensor_format == "np": - return np.random.randn(*np.shape(tensor)) - elif tensor_format == "pt": - # return torch.randn_like(tensor) - return torch.randn(tensor.shape, layout=tensor.layout, generator=generator).to(tensor.device) - - raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.") - - def zeros_like(self, tensor): - tensor_format = getattr(self, "tensor_format", "pt") - if tensor_format == "np": - return np.zeros_like(tensor) - elif tensor_format == "pt": - return torch.zeros_like(tensor) - - raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.") From 7822977e55344d17752d84ab132b07dbc17a8505 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Mon, 19 Sep 2022 11:31:15 +0200 Subject: [PATCH 21/33] fix imports --- src/diffusers/schedulers/scheduling_ddim.py | 2 +- src/diffusers/schedulers/scheduling_ddpm.py | 2 +- src/diffusers/schedulers/scheduling_lms_discrete.py | 2 +- src/diffusers/schedulers/scheduling_pndm.py | 2 +- src/diffusers/schedulers/scheduling_sde_ve.py | 2 +- 5 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_ddim.py b/src/diffusers/schedulers/scheduling_ddim.py index d33f8612ca6c..61169fef33bc 100644 --- a/src/diffusers/schedulers/scheduling_ddim.py +++ b/src/diffusers/schedulers/scheduling_ddim.py @@ -23,7 +23,7 @@ import torch from ..configuration_utils import ConfigMixin, register_to_config -from .scheduling_utils import SchedulerOutput, SchedulerMixin +from .scheduling_utils import SchedulerMixin, SchedulerOutput def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> torch.Tensor: diff --git a/src/diffusers/schedulers/scheduling_ddpm.py b/src/diffusers/schedulers/scheduling_ddpm.py index affdac9f1d96..c4aa78fbcf41 100644 --- a/src/diffusers/schedulers/scheduling_ddpm.py +++ b/src/diffusers/schedulers/scheduling_ddpm.py @@ -21,7 +21,7 @@ import torch from ..configuration_utils import ConfigMixin, register_to_config -from .scheduling_utils import SchedulerOutput, SchedulerMixin +from .scheduling_utils import SchedulerMixin, SchedulerOutput def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999): diff --git a/src/diffusers/schedulers/scheduling_lms_discrete.py b/src/diffusers/schedulers/scheduling_lms_discrete.py index 9e5320353de2..9cbb7f3bad3f 100644 --- a/src/diffusers/schedulers/scheduling_lms_discrete.py +++ b/src/diffusers/schedulers/scheduling_lms_discrete.py @@ -20,7 +20,7 @@ from scipy import integrate from ..configuration_utils import ConfigMixin, register_to_config -from .scheduling_utils import SchedulerOutput, SchedulerMixin +from .scheduling_utils import SchedulerMixin, SchedulerOutput class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin): diff --git a/src/diffusers/schedulers/scheduling_pndm.py b/src/diffusers/schedulers/scheduling_pndm.py index 
bd2e1b91bfea..0b1af26acb8c 100644 --- a/src/diffusers/schedulers/scheduling_pndm.py +++ b/src/diffusers/schedulers/scheduling_pndm.py @@ -22,7 +22,7 @@ import torch from ..configuration_utils import ConfigMixin, register_to_config -from .scheduling_utils import SchedulerOutput, SchedulerMixin +from .scheduling_utils import SchedulerMixin, SchedulerOutput def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999): diff --git a/src/diffusers/schedulers/scheduling_sde_ve.py b/src/diffusers/schedulers/scheduling_sde_ve.py index 6c40a09fa196..59fbf985e22e 100644 --- a/src/diffusers/schedulers/scheduling_sde_ve.py +++ b/src/diffusers/schedulers/scheduling_sde_ve.py @@ -23,7 +23,7 @@ from ..configuration_utils import ConfigMixin, register_to_config from ..utils import BaseOutput -from .scheduling_utils import SchedulerOutput, SchedulerMixin +from .scheduling_utils import SchedulerMixin, SchedulerOutput @dataclass From 4c91e362dbcf2ed7638bcb4aff96466b4a3991e6 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Mon, 19 Sep 2022 11:40:25 +0200 Subject: [PATCH 22/33] removed unsed imports --- src/diffusers/schedulers/scheduling_utils.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_utils.py b/src/diffusers/schedulers/scheduling_utils.py index a47aaea0781c..29bf982f6adf 100644 --- a/src/diffusers/schedulers/scheduling_utils.py +++ b/src/diffusers/schedulers/scheduling_utils.py @@ -12,9 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from dataclasses import dataclass -from typing import Union -import numpy as np import torch from ..utils import BaseOutput From 32c3d57f6dc91281761e4ce849df2778b1e2f02e Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Mon, 19 Sep 2022 12:13:35 +0200 Subject: [PATCH 23/33] match shape instead of assuming image shapes --- src/diffusers/schedulers/scheduling_sde_ve.py | 14 +++++++++---- src/diffusers/schedulers/scheduling_sde_vp.py | 20 +++++++++++++------ 2 files changed, 24 insertions(+), 10 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_sde_ve.py b/src/diffusers/schedulers/scheduling_sde_ve.py index 59fbf985e22e..1fa74fdebf34 100644 --- a/src/diffusers/schedulers/scheduling_sde_ve.py +++ b/src/diffusers/schedulers/scheduling_sde_ve.py @@ -187,13 +187,16 @@ def step_pred( # equation 6 in the paper: the model_output modeled by the network is grad_x log pt(x) # also equation 47 shows the analog from SDE models to ancestral sampling methods - drift = drift - diffusion[:, None, None, None] ** 2 * model_output + diffusion = diffusion.flatten() + while len(diffusion.shape) < len(sample.shape): + diffusion = diffusion.unsqueeze(-1) + drift = drift - diffusion**2 * model_output # equation 6: sample noise for the diffusion term of noise = torch.randn(sample.shape, layout=sample.layout, generator=generator).to(sample.device) prev_sample_mean = sample - drift # subtract because `dt` is a small negative timestep # TODO is the variable diffusion the correct scaling term for the noise? 
- prev_sample = prev_sample_mean + diffusion[:, None, None, None] * noise # add impact of diffusion field g + prev_sample = prev_sample_mean + diffusion * noise # add impact of diffusion field g if not return_dict: return (prev_sample, prev_sample_mean) @@ -244,8 +247,11 @@ def step_correct( # self.repeat_scalar(step_size, sample.shape[0]) # compute corrected sample: model_output term and noise term - prev_sample_mean = sample + step_size[:, None, None, None] * model_output - prev_sample = prev_sample_mean + ((step_size * 2) ** 0.5)[:, None, None, None] * noise + step_size = step_size.flatten() + while len(step_size.shape) < len(sample.shape): + step_size = step_size.unsqueeze(-1) + prev_sample_mean = sample + step_size * model_output + prev_sample = prev_sample_mean + ((step_size * 2) ** 0.5) * noise if not return_dict: return (prev_sample,) diff --git a/src/diffusers/schedulers/scheduling_sde_vp.py b/src/diffusers/schedulers/scheduling_sde_vp.py index 78acc3b36a64..0e4c1d587bd2 100644 --- a/src/diffusers/schedulers/scheduling_sde_vp.py +++ b/src/diffusers/schedulers/scheduling_sde_vp.py @@ -16,6 +16,7 @@ # TODO(Patrick, Anton, Suraj) - make scheduler framework independent and clean-up a bit +from operator import ge import numpy as np import torch @@ -47,7 +48,7 @@ def __init__(self, num_train_timesteps=2000, beta_min=0.1, beta_max=20, sampling def set_timesteps(self, num_inference_steps): self.timesteps = torch.linspace(1, self.config.sampling_eps, num_inference_steps) - def step_pred(self, score, x, t): + def step_pred(self, score, x, t, generator=None): if self.timesteps is None: raise ValueError( "`self.timesteps` is not set, you need to run 'set_timesteps' after creating the scheduler" @@ -59,20 +60,27 @@ def step_pred(self, score, x, t): -0.25 * t**2 * (self.config.beta_max - self.config.beta_min) - 0.5 * t * self.config.beta_min ) std = torch.sqrt(1.0 - torch.exp(2.0 * log_mean_coeff)) - score = -score / std[:, None, None, None] + std = std.flatten() + while len(std.shape) < len(score.shape): + std = std.unsqueeze(-1) + score = -score / std # compute dt = -1.0 / len(self.timesteps) beta_t = self.config.beta_min + t * (self.config.beta_max - self.config.beta_min) - drift = -0.5 * beta_t[:, None, None, None] * x + beta_t = beta_t.flatten() + while len(beta_t.shape) < len(x.shape): + beta_t = beta_t.unsqueeze(-1) + drift = -0.5 * beta_t * x + diffusion = torch.sqrt(beta_t) - drift = drift - diffusion[:, None, None, None] ** 2 * score + drift = drift - diffusion**2 * score x_mean = x + drift * dt # add noise - noise = torch.randn_like(x) - x = x_mean + diffusion[:, None, None, None] * np.sqrt(-dt) * noise + noise = torch.randn(x.shape, layout=x.layout, generator=generator).to(x.device) + x = x_mean + diffusion * np.sqrt(-dt) * noise return x, x_mean From 3f484f382068ce369605fee809b0d07b2065e501 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Mon, 19 Sep 2022 12:20:28 +0200 Subject: [PATCH 24/33] remove import typo --- src/diffusers/schedulers/scheduling_sde_vp.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/diffusers/schedulers/scheduling_sde_vp.py b/src/diffusers/schedulers/scheduling_sde_vp.py index 0e4c1d587bd2..7a9ab8e36cee 100644 --- a/src/diffusers/schedulers/scheduling_sde_vp.py +++ b/src/diffusers/schedulers/scheduling_sde_vp.py @@ -16,7 +16,6 @@ # TODO(Patrick, Anton, Suraj) - make scheduler framework independent and clean-up a bit -from operator import ge import numpy as np import torch From 1b5fbc09a4826a15091786127b299603a48c84cf Mon Sep 17 00:00:00 2001 From: 
Kashif Rasul Date: Fri, 23 Sep 2022 22:12:15 +0200 Subject: [PATCH 25/33] update call to add_noise --- .../stable_diffusion/pipeline_stable_diffusion_inpaint.py | 2 +- src/diffusers/schedulers/scheduling_pndm.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py index 6ea72234c671..d4625377db4b 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py @@ -315,7 +315,7 @@ def __call__( else: latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample # masking - init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, t) + init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.LongTensor([t])) latents = (init_latents_proper * mask) + (latents * (1 - mask)) diff --git a/src/diffusers/schedulers/scheduling_pndm.py b/src/diffusers/schedulers/scheduling_pndm.py index 0b1af26acb8c..9cae6f0d879d 100644 --- a/src/diffusers/schedulers/scheduling_pndm.py +++ b/src/diffusers/schedulers/scheduling_pndm.py @@ -379,7 +379,7 @@ def add_noise( self, original_samples: torch.FloatTensor, noise: torch.FloatTensor, - timesteps: torch.FloatTensor, + timesteps: torch.IntTensor, ) -> torch.Tensor: timesteps = timesteps.to(self.alphas_cumprod.device) From 0c8e83f63d1f933d0823ccd215cc1b47c83c49bf Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Fri, 23 Sep 2022 22:24:17 +0200 Subject: [PATCH 26/33] use math instead of numpy --- src/diffusers/schedulers/scheduling_sde_ve.py | 4 ++-- src/diffusers/schedulers/scheduling_sde_vp.py | 5 +++-- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_sde_ve.py b/src/diffusers/schedulers/scheduling_sde_ve.py index 1fa74fdebf34..7b06ae16c5e9 100644 --- a/src/diffusers/schedulers/scheduling_sde_ve.py +++ b/src/diffusers/schedulers/scheduling_sde_ve.py @@ -14,11 +14,11 @@ # DISCLAIMER: This file is strongly influenced by https://github.com/yang-song/score_sde_pytorch +import math import warnings from dataclasses import dataclass from typing import Optional, Tuple, Union -import numpy as np import torch from ..configuration_utils import ConfigMixin, register_to_config @@ -120,7 +120,7 @@ def set_sigmas( self.set_timesteps(num_inference_steps, sampling_eps) self.sigmas = sigma_min * (sigma_max / sigma_min) ** (self.timesteps / sampling_eps) - self.discrete_sigmas = torch.exp(torch.linspace(np.log(sigma_min), np.log(sigma_max), num_inference_steps)) + self.discrete_sigmas = torch.exp(torch.linspace(math.log(sigma_min), math.log(sigma_max), num_inference_steps)) self.sigmas = torch.tensor([sigma_min * (sigma_max / sigma_min) ** t for t in self.timesteps]) def get_adjacent_sigma(self, timesteps, t): diff --git a/src/diffusers/schedulers/scheduling_sde_vp.py b/src/diffusers/schedulers/scheduling_sde_vp.py index 7a9ab8e36cee..2f9821579c52 100644 --- a/src/diffusers/schedulers/scheduling_sde_vp.py +++ b/src/diffusers/schedulers/scheduling_sde_vp.py @@ -16,7 +16,8 @@ # TODO(Patrick, Anton, Suraj) - make scheduler framework independent and clean-up a bit -import numpy as np +import math + import torch from ..configuration_utils import ConfigMixin, register_to_config @@ -79,7 +80,7 @@ def step_pred(self, score, x, t, generator=None): # add noise noise = torch.randn(x.shape, layout=x.layout, 
generator=generator).to(x.device) - x = x_mean + diffusion * np.sqrt(-dt) * noise + x = x_mean + diffusion * math.sqrt(-dt) * noise return x, x_mean From a4e80d8216f993d3adc0e3660cf67080c1a9f4e1 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Fri, 23 Sep 2022 22:31:34 +0200 Subject: [PATCH 27/33] fix t_index --- .../stable_diffusion/pipeline_stable_diffusion_inpaint.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py index d4625377db4b..d7e06b7fd7b0 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py @@ -311,7 +311,7 @@ def __call__( if isinstance(self.scheduler, LMSDiscreteScheduler): latents = self.scheduler.step(noise_pred, t_index, latents, **extra_step_kwargs).prev_sample # masking - init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.tensor(t_index)) + init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.LongTensor([t_index])) else: latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample # masking From 64856450fd52181131ae685153139bff5be3604c Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Fri, 23 Sep 2022 22:40:37 +0200 Subject: [PATCH 28/33] removed commented out numpy tests --- tests/test_scheduler.py | 120 ---------------------------------------- 1 file changed, 120 deletions(-) diff --git a/tests/test_scheduler.py b/tests/test_scheduler.py index 007a008b5f62..8601e77a43dd 100755 --- a/tests/test_scheduler.py +++ b/tests/test_scheduler.py @@ -173,34 +173,6 @@ def test_step_shape(self): self.assertEqual(output_0.shape, sample.shape) self.assertEqual(output_0.shape, output_1.shape) - # def test_pytorch_equal_numpy(self): - # kwargs = dict(self.forward_default_kwargs) - - # num_inference_steps = kwargs.pop("num_inference_steps", None) - - # for scheduler_class in self.scheduler_classes: - # sample_pt = self.dummy_sample - # residual_pt = 0.1 * sample_pt - - # sample = sample_pt.numpy() - # residual = 0.1 * sample - - # scheduler_config = self.get_scheduler_config() - # scheduler = scheduler_class(tensor_format="np", **scheduler_config) - - # scheduler_pt = scheduler_class(tensor_format="pt", **scheduler_config) - - # if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): - # scheduler.set_timesteps(num_inference_steps) - # scheduler_pt.set_timesteps(num_inference_steps) - # elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): - # kwargs["num_inference_steps"] = num_inference_steps - - # output = scheduler.step(residual, 1, sample, **kwargs).prev_sample - # output_pt = scheduler_pt.step(residual_pt, 1, sample_pt, **kwargs).prev_sample - - # assert np.sum(np.abs(output - output_pt.numpy())) < 1e-4, "Scheduler outputs are not identical" - def test_scheduler_outputs_equivalence(self): def set_nan_tensor_to_zero(t): t[t != t] = 0 @@ -304,10 +276,6 @@ def test_variance(self): assert torch.sum(torch.abs(scheduler._get_variance(487) - 0.00979)) < 1e-5 assert torch.sum(torch.abs(scheduler._get_variance(999) - 0.02)) < 1e-5 - # TODO Make DDPM Numpy compatible - def test_pytorch_equal_numpy(self): - pass - def test_full_loop_no_noise(self): scheduler_class = self.scheduler_classes[0] scheduler_config = self.get_scheduler_config() @@ -555,72 +523,6 @@ def 
full_loop(self, **config): return sample - # def test_pytorch_equal_numpy(self): - # kwargs = dict(self.forward_default_kwargs) - # num_inference_steps = kwargs.pop("num_inference_steps", None) - - # for scheduler_class in self.scheduler_classes: - # sample_pt = self.dummy_sample - # residual_pt = 0.1 * sample_pt - # dummy_past_residuals_pt = [residual_pt + 0.2, residual_pt + 0.15, residual_pt + 0.1, residual_pt + 0.05] - - # sample = sample_pt.numpy() - # residual = 0.1 * sample - # dummy_past_residuals = [residual + 0.2, residual + 0.15, residual + 0.1, residual + 0.05] - - # scheduler_config = self.get_scheduler_config() - # scheduler = scheduler_class(tensor_format="np", **scheduler_config) - - # scheduler_pt = scheduler_class(tensor_format="pt", **scheduler_config) - - # if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): - # scheduler.set_timesteps(num_inference_steps) - # scheduler_pt.set_timesteps(num_inference_steps) - # elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): - # kwargs["num_inference_steps"] = num_inference_steps - - # # copy over dummy past residuals (must be done after set_timesteps) - # scheduler.ets = dummy_past_residuals[:] - # scheduler_pt.ets = dummy_past_residuals_pt[:] - - # output = scheduler.step_prk(residual, 1, sample, **kwargs).prev_sample - # output_pt = scheduler_pt.step_prk(residual_pt, 1, sample_pt, **kwargs).prev_sample - # assert np.sum(np.abs(output - output_pt.numpy())) < 1e-4, "Scheduler outputs are not identical" - - # output = scheduler.step_plms(residual, 1, sample, **kwargs).prev_sample - # output_pt = scheduler_pt.step_plms(residual_pt, 1, sample_pt, **kwargs).prev_sample - - # assert np.sum(np.abs(output - output_pt.numpy())) < 1e-4, "Scheduler outputs are not identical" - - # def test_set_format(self): - # kwargs = dict(self.forward_default_kwargs) - # num_inference_steps = kwargs.pop("num_inference_steps", None) - - # for scheduler_class in self.scheduler_classes: - # scheduler_config = self.get_scheduler_config() - # scheduler = scheduler_class(tensor_format="np", **scheduler_config) - # scheduler_pt = scheduler_class(tensor_format="pt", **scheduler_config) - - # if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): - # scheduler.set_timesteps(num_inference_steps) - # scheduler_pt.set_timesteps(num_inference_steps) - - # for key, value in vars(scheduler).items(): - # # we only allow `ets` attr to be a list - # assert not isinstance(value, list) or key in [ - # "ets" - # ], f"Scheduler is not correctly set to np format, the attribute {key} is {type(value)}" - - # # check if `scheduler.set_format` does convert correctly attrs to pt format - # for key, value in vars(scheduler_pt).items(): - # # we only allow `ets` attr to be a list - # assert not isinstance(value, list) or key in [ - # "ets" - # ], f"Scheduler is not correctly set to pt format, the attribute {key} is {type(value)}" - # assert not isinstance( - # value, np.ndarray - # ), f"Scheduler is not correctly set to pt format, the attribute {key} is {type(value)}" - def test_step_shape(self): kwargs = dict(self.forward_default_kwargs) @@ -953,28 +855,6 @@ def test_time_indices(self): for t in [0, 500, 800]: self.check_over_forward(time_step=t) - # def test_pytorch_equal_numpy(self): - # for scheduler_class in self.scheduler_classes: - # sample_pt = self.dummy_sample - # residual_pt = 0.1 * sample_pt - - # sample = sample_pt.numpy() - # residual = 0.1 * sample - - # scheduler_config = 
self.get_scheduler_config() - # scheduler_config["tensor_format"] = "np" - # scheduler = scheduler_class(**scheduler_config) - - # scheduler_config["tensor_format"] = "pt" - # scheduler_pt = scheduler_class(**scheduler_config) - - # scheduler.set_timesteps(self.num_inference_steps) - # scheduler_pt.set_timesteps(self.num_inference_steps) - - # output = scheduler.step(residual, 1, sample).prev_sample - # output_pt = scheduler_pt.step(residual_pt, 1, sample_pt).prev_sample - # assert np.sum(np.abs(output - output_pt.numpy())) < 1e-4, "Scheduler outputs are not identical" - def test_full_loop_no_noise(self): scheduler_class = self.scheduler_classes[0] scheduler_config = self.get_scheduler_config() From f0b7aba3b0ed09e80ed27ced9eaa3aa1eecdc108 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Sat, 24 Sep 2022 13:23:33 +0200 Subject: [PATCH 29/33] timesteps needs to be discrete --- examples/community/clip_guided_stable_diffusion.py | 2 +- src/diffusers/schedulers/scheduling_lms_discrete.py | 11 ++++++----- tests/test_scheduler.py | 2 +- 3 files changed, 8 insertions(+), 7 deletions(-) diff --git a/examples/community/clip_guided_stable_diffusion.py b/examples/community/clip_guided_stable_diffusion.py index f78175735931..33fba10b2687 100644 --- a/examples/community/clip_guided_stable_diffusion.py +++ b/examples/community/clip_guided_stable_diffusion.py @@ -274,7 +274,7 @@ def __call__( # the model input needs to be scaled to match the continuous ODE formulation in K-LMS latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5) - # # predict the noise residual + # predict the noise residual noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample # perform classifier free guidance diff --git a/src/diffusers/schedulers/scheduling_lms_discrete.py b/src/diffusers/schedulers/scheduling_lms_discrete.py index 9b5fe7c71583..0db90cba27a7 100644 --- a/src/diffusers/schedulers/scheduling_lms_discrete.py +++ b/src/diffusers/schedulers/scheduling_lms_discrete.py @@ -96,7 +96,7 @@ def __init__( # setable values self.num_inference_steps = None - self.timesteps = np.arange(0, num_train_timesteps)[::-1].copy() + self.timesteps = np.arange(0, num_train_timesteps)[::-1] # to be consistent has to be smaller than sigmas by 1 self.derivatives = [] def get_lms_coefficient(self, order, t, current_order): @@ -130,16 +130,17 @@ def set_timesteps(self, num_inference_steps: int): the number of diffusion steps used when generating samples with a pre-trained model. 
""" self.num_inference_steps = num_inference_steps - self.timesteps = np.linspace(self.config.num_train_timesteps - 1, 0, num_inference_steps, dtype=float) + timesteps = np.linspace(self.config.num_train_timesteps - 1, 0, num_inference_steps, dtype=float) - low_idx = np.floor(self.timesteps).astype(int) - high_idx = np.ceil(self.timesteps).astype(int) - frac = np.mod(self.timesteps, 1.0) + low_idx = np.floor(timesteps).astype(int) + high_idx = np.ceil(timesteps).astype(int) + frac = np.mod(timesteps, 1.0) sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) sigmas = (1 - frac) * sigmas[low_idx] + frac * sigmas[high_idx] sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32) self.sigmas = torch.from_numpy(sigmas) + self.timesteps = timesteps.astype(int) self.derivatives = [] def step( diff --git a/tests/test_scheduler.py b/tests/test_scheduler.py index 8601e77a43dd..cf3e607ea9d2 100755 --- a/tests/test_scheduler.py +++ b/tests/test_scheduler.py @@ -876,5 +876,5 @@ def test_full_loop_no_noise(self): result_sum = torch.sum(torch.abs(sample)) result_mean = torch.mean(torch.abs(sample)) - assert abs(result_sum.item() - 1006.388) < 1e-2 + assert abs(result_sum.item() - 1006.370) < 1e-2 assert abs(result_mean.item() - 1.31) < 1e-3 From 2cb7a3373e9b4f8a78bc277b65b86a24804675f2 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Mon, 26 Sep 2022 11:25:26 +0200 Subject: [PATCH 30/33] cast timesteps to int in flax scheduler too --- src/diffusers/schedulers/scheduling_lms_discrete_flax.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/schedulers/scheduling_lms_discrete_flax.py b/src/diffusers/schedulers/scheduling_lms_discrete_flax.py index 7f4c076b54d1..20080c200dad 100644 --- a/src/diffusers/schedulers/scheduling_lms_discrete_flax.py +++ b/src/diffusers/schedulers/scheduling_lms_discrete_flax.py @@ -134,7 +134,7 @@ def set_timesteps(self, state: LMSDiscreteSchedulerState, num_inference_steps: i return state.replace( num_inference_steps=num_inference_steps, - timesteps=timesteps, + timesteps=timesteps.astype(int), derivatives=jnp.array([]), sigmas=sigmas, ) From 9cacd3fd3839863e01cd0d5a362af787c6169703 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Mon, 26 Sep 2022 19:35:18 +0200 Subject: [PATCH 31/33] fix device mismatch issue --- src/diffusers/schedulers/scheduling_lms_discrete.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_lms_discrete.py b/src/diffusers/schedulers/scheduling_lms_discrete.py index 0db90cba27a7..8be5197f10a7 100644 --- a/src/diffusers/schedulers/scheduling_lms_discrete.py +++ b/src/diffusers/schedulers/scheduling_lms_discrete.py @@ -200,9 +200,10 @@ def add_noise( noise: torch.FloatTensor, timesteps: torch.IntTensor, ) -> torch.FloatTensor: - timesteps = timesteps.to(self.sigmas.device) + sigmas = self.sigmas.to(original_samples.device) + timesteps = timesteps.to(original_samples.device) - sigma = self.sigmas[timesteps].flatten() + sigma = sigmas[timesteps].flatten() while len(sigma.shape) < len(original_samples.shape): sigma = sigma.unsqueeze(-1) From 4e70c10f557ec3b6404e51a9b7fc438a8af1889b Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 27 Sep 2022 12:54:01 +0000 Subject: [PATCH 32/33] small fix --- src/diffusers/schedulers/scheduling_pndm.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/src/diffusers/schedulers/scheduling_pndm.py b/src/diffusers/schedulers/scheduling_pndm.py index 9cae6f0d879d..6b1c2210b87a 100644 --- 
a/src/diffusers/schedulers/scheduling_pndm.py +++ b/src/diffusers/schedulers/scheduling_pndm.py @@ -381,6 +381,13 @@ def add_noise( noise: torch.FloatTensor, timesteps: torch.IntTensor, ) -> torch.Tensor: + if self.alphas_cumprod.device != original_samples.device: + self.alphas_cumprod = self.alphas_cumprod.to(original_samples.device) + + if timesteps.device != original_samples.device: + timesteps = timesteps.to(original_samples.device) + + self.alphas_cumprod timesteps = timesteps.to(self.alphas_cumprod.device) sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5 From b2dbd79ca8ee5dd4ebd6c2146f310f44f9bbb8aa Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 27 Sep 2022 14:59:46 +0200 Subject: [PATCH 33/33] Update src/diffusers/schedulers/scheduling_pndm.py --- src/diffusers/schedulers/scheduling_pndm.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/diffusers/schedulers/scheduling_pndm.py b/src/diffusers/schedulers/scheduling_pndm.py index 6b1c2210b87a..1935a6ef93f2 100644 --- a/src/diffusers/schedulers/scheduling_pndm.py +++ b/src/diffusers/schedulers/scheduling_pndm.py @@ -387,7 +387,6 @@ def add_noise( if timesteps.device != original_samples.device: timesteps = timesteps.to(original_samples.device) - self.alphas_cumprod timesteps = timesteps.to(self.alphas_cumprod.device) sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
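
Note on the device-handling pattern (a sketch for illustration, not part of any patch above): patches 31-33 converge on one convention for `add_noise` -- move the scheduler's cached schedule tensors (`sigmas` or `alphas_cumprod`) onto the device of the incoming samples before indexing them with the timestep tensor, so a CPU-initialized scheduler keeps working with CUDA latents. A minimal Python sketch of that pattern, following the LMS-discrete hunk from patch 31; the final sigma-scaled addition is an assumption here and does not appear in the quoted hunks:

    def add_noise(self, original_samples, noise, timesteps):
        # keep the cached schedule constants on the same device as the samples
        sigmas = self.sigmas.to(original_samples.device)
        timesteps = timesteps.to(original_samples.device)

        # index the per-timestep sigmas with the (integer) timestep tensor
        sigma = sigmas[timesteps].flatten()
        # broadcast sigma across the sample's trailing dimensions
        while len(sigma.shape) < len(original_samples.shape):
            sigma = sigma.unsqueeze(-1)

        # assumed combination step: scale the noise by sigma and add it
        return original_samples + noise * sigma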