We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent b58e883 commit 6c1a5ddCopy full SHA for 6c1a5dd
monai/networks/schedulers/scheduler.py
@@ -105,9 +105,11 @@ def _cosine_beta(num_train_timesteps: int, s: float = 8e-3):
105
x = torch.linspace(0, num_train_timesteps, num_train_timesteps + 1)
106
alphas_cumprod = torch.cos(((x / num_train_timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2
107
alphas_cumprod /= alphas_cumprod[0].item()
108
- alphas = torch.clip(alphas_cumprod[1:] / alphas_cumprod[:-1], 0.0001, 0.9999)
109
- betas = 1.0 - alphas
110
- return betas, alphas, alphas_cumprod[:-1]
+ betas = 1.0 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
+ betas = torch.clip(betas, 0.0, 0.999)
+ alphas = 1.0 - betas
111
+ alphas_cumprod = torch.cumprod(alphas, dim=0)
112
+ return betas, alphas, alphas_cumprod
113
114
115
class Scheduler(nn.Module):
0 commit comments