Skip to content

Commit 6c1a5dd

Browse files
committed
🐛 fix cosine noise scheduler
Signed-off-by: Slava Shen <[email protected]>
1 parent b58e883 commit 6c1a5dd

File tree

1 file changed

+5
-3
lines changed

1 file changed

+5
-3
lines changed

monai/networks/schedulers/scheduler.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -105,9 +105,11 @@ def _cosine_beta(num_train_timesteps: int, s: float = 8e-3):
105105
x = torch.linspace(0, num_train_timesteps, num_train_timesteps + 1)
106106
alphas_cumprod = torch.cos(((x / num_train_timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2
107107
alphas_cumprod /= alphas_cumprod[0].item()
108-
alphas = torch.clip(alphas_cumprod[1:] / alphas_cumprod[:-1], 0.0001, 0.9999)
109-
betas = 1.0 - alphas
110-
return betas, alphas, alphas_cumprod[:-1]
108+
betas = 1.0 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
109+
betas = torch.clip(betas, 0.0, 0.999)
110+
alphas = 1.0 - betas
111+
alphas_cumprod = torch.cumprod(alphas, dim=0)
112+
return betas, alphas, alphas_cumprod
111113

112114

113115
class Scheduler(nn.Module):

0 commit comments

Comments
 (0)