Skip to content

Commit e55f460

Browse files
authored
Fix the LMS pytorch regression (huggingface#664)
* Fix the LMS pytorch regression * Copy over the changes from huggingface#637 * Copy over the changes from huggingface#637 * Fix betas test
1 parent ade5778 commit e55f460

File tree

2 files changed

+8
-9
lines changed

2 files changed

+8
-9
lines changed

Diff for: dependency_versions_table.py

-1
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
"jaxlib": "jaxlib>=0.1.65,<=0.3.6",
1818
"modelcards": "modelcards>=0.1.4",
1919
"numpy": "numpy",
20-
"onnxruntime": "onnxruntime",
2120
"onnxruntime-gpu": "onnxruntime-gpu",
2221
"pytest": "pytest",
2322
"pytest-timeout": "pytest-timeout",

Diff for: schedulers/scheduling_lms_discrete.py

+8-8
Original file line numberDiff line numberDiff line change
@@ -99,11 +99,14 @@ def __init__(
9999
self.alphas = 1.0 - self.betas
100100
self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
101101

102-
self.sigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5
102+
sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
103+
sigmas = np.concatenate([sigmas[::-1], [0.0]]).astype(np.float32)
104+
self.sigmas = torch.from_numpy(sigmas)
103105

104106
# setable values
105107
self.num_inference_steps = None
106-
self.timesteps = np.arange(0, num_train_timesteps)[::-1] # to be consistent has to be smaller than sigmas by 1
108+
timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=float)[::-1].copy()
109+
self.timesteps = torch.from_numpy(timesteps)
107110
self.derivatives = []
108111

109112
def get_lms_coefficient(self, order, t, current_order):
@@ -137,17 +140,14 @@ def set_timesteps(self, num_inference_steps: int):
137140
the number of diffusion steps used when generating samples with a pre-trained model.
138141
"""
139142
self.num_inference_steps = num_inference_steps
140-
timesteps = np.linspace(self.config.num_train_timesteps - 1, 0, num_inference_steps, dtype=float)
141143

142-
low_idx = np.floor(timesteps).astype(int)
143-
high_idx = np.ceil(timesteps).astype(int)
144-
frac = np.mod(timesteps, 1.0)
144+
timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy()
145145
sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
146-
sigmas = (1 - frac) * sigmas[low_idx] + frac * sigmas[high_idx]
146+
sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
147147
sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
148148
self.sigmas = torch.from_numpy(sigmas)
149+
self.timesteps = torch.from_numpy(timesteps)
149150

150-
self.timesteps = timesteps.astype(int)
151151
self.derivatives = []
152152

153153
def step(

0 commit comments

Comments
 (0)