From 729273d6be481dc69055da97d13149975955639f Mon Sep 17 00:00:00 2001
From: anton-l
Date: Wed, 28 Sep 2022 13:24:02 +0200
Subject: [PATCH 1/4] Fix the LMS pytorch regression

---
 .../schedulers/scheduling_lms_discrete.py | 17 +++++++++--------
 tests/test_scheduler.py                   |  2 +-
 2 files changed, 10 insertions(+), 9 deletions(-)

diff --git a/src/diffusers/schedulers/scheduling_lms_discrete.py b/src/diffusers/schedulers/scheduling_lms_discrete.py
index 6d8db7682db5..165378eac60b 100644
--- a/src/diffusers/schedulers/scheduling_lms_discrete.py
+++ b/src/diffusers/schedulers/scheduling_lms_discrete.py
@@ -85,25 +85,26 @@ def __init__(
         )
 
         if trained_betas is not None:
-            self.betas = torch.from_numpy(trained_betas)
+            self.betas = np.asarray(trained_betas)
         if beta_schedule == "linear":
-            self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
+            self.betas = np.linspace(beta_start, beta_end, num_train_timesteps, dtype=np.float32)
         elif beta_schedule == "scaled_linear":
             # this schedule is very specific to the latent diffusion model.
             self.betas = (
-                torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
+                np.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=np.float32) ** 2
             )
         else:
             raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
 
-        self.alphas = 1.0 - self.betas
-        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
+        self.alphas = np.array(1.0 - self.betas, dtype=np.float32)
+        self.alphas_cumprod = np.cumprod(self.alphas, axis=0)
 
-        self.sigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5
+        sigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5
+        self.sigmas = torch.from_numpy(sigmas)
 
         # setable values
         self.num_inference_steps = None
-        self.timesteps = np.arange(0, num_train_timesteps)[::-1]  # to be consistent has to be smaller than sigmas by 1
+        self.timesteps = np.arange(0, num_train_timesteps)[::-1]
         self.derivatives = []
 
     def get_lms_coefficient(self, order, t, current_order):
@@ -146,8 +147,8 @@ def set_timesteps(self, num_inference_steps: int):
         sigmas = (1 - frac) * sigmas[low_idx] + frac * sigmas[high_idx]
         sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
         self.sigmas = torch.from_numpy(sigmas)
+        self.timesteps = timesteps
 
-        self.timesteps = timesteps.astype(int)
         self.derivatives = []
 
     def step(
diff --git a/tests/test_scheduler.py b/tests/test_scheduler.py
index cf3e607ea9d2..8601e77a43dd 100755
--- a/tests/test_scheduler.py
+++ b/tests/test_scheduler.py
@@ -876,5 +876,5 @@ def test_full_loop_no_noise(self):
         result_sum = torch.sum(torch.abs(sample))
         result_mean = torch.mean(torch.abs(sample))
 
-        assert abs(result_sum.item() - 1006.370) < 1e-2
+        assert abs(result_sum.item() - 1006.388) < 1e-2
         assert abs(result_mean.item() - 1.31) < 1e-3

From 0a682b06a41ff452a50671bc1c735455dc6d013f Mon Sep 17 00:00:00 2001
From: anton-l
Date: Wed, 28 Sep 2022 13:32:53 +0200
Subject: [PATCH 2/4] Copy over the changes from #637

---
 .../schedulers/scheduling_lms_discrete.py | 25 ++++++++++++-------------
 1 file changed, 12 insertions(+), 13 deletions(-)

diff --git a/src/diffusers/schedulers/scheduling_lms_discrete.py b/src/diffusers/schedulers/scheduling_lms_discrete.py
index 165378eac60b..4595b2fe5aaf 100644
--- a/src/diffusers/schedulers/scheduling_lms_discrete.py
+++ b/src/diffusers/schedulers/scheduling_lms_discrete.py
@@ -85,26 +85,28 @@ def __init__(
         )
 
         if trained_betas is not None:
-            self.betas = np.asarray(trained_betas)
+            self.betas = torch.from_numpy(trained_betas)
         if beta_schedule == "linear":
-            self.betas = np.linspace(beta_start, beta_end, num_train_timesteps, dtype=np.float32)
+            self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
         elif beta_schedule == "scaled_linear":
             # this schedule is very specific to the latent diffusion model.
             self.betas = (
-                np.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=np.float32) ** 2
+                torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
             )
         else:
             raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
 
-        self.alphas = np.array(1.0 - self.betas, dtype=np.float32)
-        self.alphas_cumprod = np.cumprod(self.alphas, axis=0)
+        self.alphas = 1.0 - self.betas
+        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
 
-        sigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5
+        sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
+        sigmas = np.concatenate([sigmas[::-1], [0.0]]).astype(np.float32)
         self.sigmas = torch.from_numpy(sigmas)
 
         # setable values
         self.num_inference_steps = None
-        self.timesteps = np.arange(0, num_train_timesteps)[::-1]
+        timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=float)[::-1].copy()
+        self.timesteps = torch.from_numpy(timesteps)
         self.derivatives = []
 
     def get_lms_coefficient(self, order, t, current_order):
@@ -138,16 +140,13 @@ def set_timesteps(self, num_inference_steps: int):
                 the number of diffusion steps used when generating samples with a pre-trained model.
         """
         self.num_inference_steps = num_inference_steps
-        timesteps = np.linspace(self.config.num_train_timesteps - 1, 0, num_inference_steps, dtype=float)
-        low_idx = np.floor(timesteps).astype(int)
-        high_idx = np.ceil(timesteps).astype(int)
-        frac = np.mod(timesteps, 1.0)
+        timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy()
 
         sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
-        sigmas = (1 - frac) * sigmas[low_idx] + frac * sigmas[high_idx]
+        sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
         sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
         self.sigmas = torch.from_numpy(sigmas)
-        self.timesteps = timesteps
+        self.timesteps = torch.from_numpy(timesteps)
 
         self.derivatives = []

From f903db0c586f03cbaee1e9fd0c3d8ba3bc08e6dc Mon Sep 17 00:00:00 2001
From: anton-l
Date: Wed, 28 Sep 2022 13:35:39 +0200
Subject: [PATCH 3/4] Copy over the changes from #637

---
 src/diffusers/dependency_versions_table.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/src/diffusers/dependency_versions_table.py b/src/diffusers/dependency_versions_table.py
index 09a7baad560d..82ca5dbb6f56 100644
--- a/src/diffusers/dependency_versions_table.py
+++ b/src/diffusers/dependency_versions_table.py
@@ -17,7 +17,6 @@
     "jaxlib": "jaxlib>=0.1.65,<=0.3.6",
     "modelcards": "modelcards>=0.1.4",
     "numpy": "numpy",
-    "onnxruntime": "onnxruntime",
     "onnxruntime-gpu": "onnxruntime-gpu",
     "pytest": "pytest",
     "pytest-timeout": "pytest-timeout",

From b493ad14715a11e6c9452b7208c0c1faaf655384 Mon Sep 17 00:00:00 2001
From: anton-l
Date: Wed, 28 Sep 2022 14:05:50 +0200
Subject: [PATCH 4/4] Fix betas test

---
 tests/test_scheduler.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/test_scheduler.py b/tests/test_scheduler.py
index 8601e77a43dd..bee36c39acdb 100755
--- a/tests/test_scheduler.py
+++ b/tests/test_scheduler.py
@@ -844,7 +844,7 @@ def test_timesteps(self):
             self.check_over_configs(num_train_timesteps=timesteps)
 
     def test_betas(self):
-        for beta_start, beta_end in zip([0.0001, 0.001, 0.01, 0.1], [0.002, 0.02, 0.2, 2]):
+        for beta_start, beta_end in zip([0.00001, 0.0001, 0.001], [0.0002, 0.002, 0.02]):
            self.check_over_configs(beta_start=beta_start, beta_end=beta_end)
 
     def test_schedules(self):
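
Note on [PATCH 2/4]: the np.interp call in set_timesteps computes the same linearly interpolated sigma schedule as the low_idx/high_idx/frac arithmetic it replaces; the behavioral changes are that timesteps stay float, get stored as torch tensors, and that __init__ now keeps a reversed sigma schedule with a trailing 0.0. The sketch below is a standalone illustration in plain NumPy, not part of the patches: num_train_timesteps, num_inference_steps, and the alphas_cumprod values are arbitrary placeholders rather than the scheduler's real state.

import numpy as np

# Placeholder stand-ins for the state that LMSDiscreteScheduler.__init__ would compute.
num_train_timesteps = 1000
num_inference_steps = 50
alphas_cumprod = np.linspace(0.99, 0.01, num_train_timesteps)  # arbitrary example values
sigmas = np.array(((1 - alphas_cumprod) / alphas_cumprod) ** 0.5)

# Old formulation (removed in PATCH 2/4): manual linear interpolation between neighboring sigmas.
timesteps_old = np.linspace(num_train_timesteps - 1, 0, num_inference_steps, dtype=float)
low_idx = np.floor(timesteps_old).astype(int)
high_idx = np.ceil(timesteps_old).astype(int)
frac = np.mod(timesteps_old, 1.0)
sigmas_old = (1 - frac) * sigmas[low_idx] + frac * sigmas[high_idx]

# New formulation (added in PATCH 2/4): the same descending timesteps, interpolated via np.interp.
timesteps_new = np.linspace(0, num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy()
sigmas_new = np.interp(timesteps_new, np.arange(0, len(sigmas)), sigmas)

# The two schedules agree up to floating-point noise.
assert np.allclose(timesteps_old, timesteps_new)
assert np.allclose(sigmas_old, sigmas_new)

Both formulations then append a trailing 0.0 to the interpolated sigmas, so the inference schedule stays in descending order and ends at zero noise, matching the descending timesteps.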