Commit 0a682b0

Copy over the changes from #637
1 parent 729273d commit 0a682b0

1 file changed: +12 -13 lines changed

Diff for: src/diffusers/schedulers/scheduling_lms_discrete.py (+12 -13)
@@ -85,26 +85,28 @@ def __init__(
         )
 
         if trained_betas is not None:
-            self.betas = np.asarray(trained_betas)
+            self.betas = torch.from_numpy(trained_betas)
         if beta_schedule == "linear":
-            self.betas = np.linspace(beta_start, beta_end, num_train_timesteps, dtype=np.float32)
+            self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
         elif beta_schedule == "scaled_linear":
             # this schedule is very specific to the latent diffusion model.
             self.betas = (
-                np.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=np.float32) ** 2
+                torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
             )
         else:
             raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
 
-        self.alphas = np.array(1.0 - self.betas, dtype=np.float32)
-        self.alphas_cumprod = np.cumprod(self.alphas, axis=0)
+        self.alphas = 1.0 - self.betas
+        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
 
-        sigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5
+        sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
+        sigmas = np.concatenate([sigmas[::-1], [0.0]]).astype(np.float32)
         self.sigmas = torch.from_numpy(sigmas)
 
         # setable values
         self.num_inference_steps = None
-        self.timesteps = np.arange(0, num_train_timesteps)[::-1]
+        timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=float)[::-1].copy()
+        self.timesteps = torch.from_numpy(timesteps)
         self.derivatives = []
 
     def get_lms_coefficient(self, order, t, current_order):
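
For context, a minimal standalone sketch (not part of the commit) of what the reworked __init__ bookkeeping now computes. It assumes the "linear" beta schedule and the scheduler's usual defaults (num_train_timesteps=1000, beta_start=0.0001, beta_end=0.02), which are not shown in this hunk:

import numpy as np
import torch

num_train_timesteps = 1000          # assumed default
beta_start, beta_end = 0.0001, 0.02  # assumed defaults

# betas/alphas now stay torch tensors end to end
betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
alphas = 1.0 - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)  # torch.cumprod takes dim=, not axis=

# sigmas are materialized once at init: reversed (highest noise first) with a trailing 0.0
sigmas = np.array(((1 - alphas_cumprod) / alphas_cumprod) ** 0.5)
sigmas = np.concatenate([sigmas[::-1], [0.0]]).astype(np.float32)
sigmas = torch.from_numpy(sigmas)

# timesteps become float-valued and wrapped in a torch tensor, counting down from T-1 to 0
timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=float)[::-1].copy()
timesteps = torch.from_numpy(timesteps)

print(sigmas.shape)     # torch.Size([1001])
print(timesteps[:3])    # tensor([999., 998., 997.], dtype=torch.float64)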
@@ -138,16 +140,13 @@ def set_timesteps(self, num_inference_steps: int):
                 the number of diffusion steps used when generating samples with a pre-trained model.
         """
         self.num_inference_steps = num_inference_steps
-        timesteps = np.linspace(self.config.num_train_timesteps - 1, 0, num_inference_steps, dtype=float)
 
-        low_idx = np.floor(timesteps).astype(int)
-        high_idx = np.ceil(timesteps).astype(int)
-        frac = np.mod(timesteps, 1.0)
+        timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy()
         sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
-        sigmas = (1 - frac) * sigmas[low_idx] + frac * sigmas[high_idx]
+        sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
         sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
         self.sigmas = torch.from_numpy(sigmas)
-        self.timesteps = timesteps
+        self.timesteps = torch.from_numpy(timesteps)
 
         self.derivatives = []
 