Skip to content

Fix the LMS pytorch regression #664

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Sep 28, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 9 additions & 8 deletions src/diffusers/schedulers/scheduling_lms_discrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,25 +85,26 @@ def __init__(
)

if trained_betas is not None:
self.betas = torch.from_numpy(trained_betas)
self.betas = np.asarray(trained_betas)
if beta_schedule == "linear":
self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
self.betas = np.linspace(beta_start, beta_end, num_train_timesteps, dtype=np.float32)
elif beta_schedule == "scaled_linear":
# this schedule is very specific to the latent diffusion model.
self.betas = (
torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
np.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=np.float32) ** 2
)
else:
raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")

self.alphas = 1.0 - self.betas
self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
self.alphas = np.array(1.0 - self.betas, dtype=np.float32)
self.alphas_cumprod = np.cumprod(self.alphas, axis=0)

self.sigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5
sigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5
self.sigmas = torch.from_numpy(sigmas)

# setable values
self.num_inference_steps = None
self.timesteps = np.arange(0, num_train_timesteps)[::-1] # to be consistent has to be smaller than sigmas by 1
self.timesteps = np.arange(0, num_train_timesteps)[::-1]
self.derivatives = []

def get_lms_coefficient(self, order, t, current_order):
Expand Down Expand Up @@ -146,8 +147,8 @@ def set_timesteps(self, num_inference_steps: int):
sigmas = (1 - frac) * sigmas[low_idx] + frac * sigmas[high_idx]
sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
self.sigmas = torch.from_numpy(sigmas)
self.timesteps = timesteps

self.timesteps = timesteps.astype(int)
self.derivatives = []

def step(
Expand Down
2 changes: 1 addition & 1 deletion tests/test_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -876,5 +876,5 @@ def test_full_loop_no_noise(self):
result_sum = torch.sum(torch.abs(sample))
result_mean = torch.mean(torch.abs(sample))

assert abs(result_sum.item() - 1006.370) < 1e-2
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Returning this to pre-#534 value

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok!

assert abs(result_sum.item() - 1006.388) < 1e-2
assert abs(result_mean.item() - 1.31) < 1e-3