🐛 fix cosine noise scheduler #8427

Open: wants to merge 1 commit into base branch dev

slavaheroes

Fixes: Update cosine noise scheduling

Description

In the current DDPMScheduler implementation, using the cosine noise schedule causes a division by zero during sampling. Specifically, scheduler.alphas_cumprod[0] == 1.0, so 1 - scheduler.alphas_cumprod[0] is exactly zero at that timestep, and the sampled image ends up containing NaN values.
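To illustrate why that first entry matters, here is a minimal sketch in plain Python. It is not MONAI's code: it assumes the cosine schedule is normalised by its first value, as in the paper's formulation, and that sampling uses the standard DDPM posterior coefficients, which have 1 - alphas_cumprod[t] in the denominator. The function name cosine_alphas_cumprod and the offset s = 0.008 are illustrative assumptions.

```python
import math


def cosine_alphas_cumprod(num_steps, s=0.008):
    """Hypothetical helper: unclipped cosine schedule, normalised by its first value."""
    f = [
        math.cos(((t / num_steps) + s) / (1 + s) * math.pi / 2) ** 2
        for t in range(num_steps + 1)
    ]
    # Assumption: the first num_steps normalised values are used directly
    # as alphas_cumprod, so entry 0 is f[0] / f[0].
    return [v / f[0] for v in f[:-1]]


alphas_cumprod = cosine_alphas_cumprod(250)

# This is the denominator of the DDPM posterior coefficients at t = 0.
beta_prod_0 = 1.0 - alphas_cumprod[0]
print(alphas_cumprod[0], beta_prod_0)  # prints: 1.0 0.0
```

With a zero denominator and a zero beta at the same step, a tensor library evaluates the coefficient as 0 / 0, which is NaN, and that NaN then propagates into the sampled image.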

You can reproduce the issue with the following snippet:

from monai.inferers import DiffusionInferer
from monai.networks.nets import DiffusionModelUNet
from monai.networks.schedulers import DDPMScheduler

import torch 
N = 250

device = 'cuda'

model = DiffusionModelUNet(
    spatial_dims=3,
    in_channels=1,
    out_channels=1,
    channels=[64, 64, 128],
    attention_levels=[False, False, True],
    num_head_channels=[0, 0, 128],
    num_res_blocks=2,
).to(device)

scheduler = DDPMScheduler(num_train_timesteps=N, schedule="cosine").to(device)
print(scheduler.alphas_cumprod[0])

inferer = DiffusionInferer(scheduler)

scheduler.set_timesteps(num_inference_steps=N)

noise = torch.randn((1, 1, 32, 40, 32))
noise = noise.to(device)
image = inferer.sample(input_noise=noise, diffusion_model=model, scheduler=scheduler)

assert torch.isfinite(image).all(), "Image contains non-finite values"

Fix

This PR clips the beta values first and then recomputes alphas_cumprod from the clipped betas before returning (monai/networks/schedulers/scheduler.py, line 112). This keeps alphas_cumprod[0] strictly below 1.0, which ensures numerical stability during sampling and prevents NaNs in the output.
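The order of operations described above can be sketched in plain Python as follows. This is an illustration under stated assumptions, not the code in this PR: the function name cosine_schedule_clipped, the offset s = 0.008, and the clipping bounds beta_min and beta_max are all hypothetical choices made for the example.

```python
import math


def cosine_schedule_clipped(num_steps, s=0.008, beta_min=1e-4, beta_max=0.999):
    """Hypothetical sketch: clip betas first, then rebuild alphas_cumprod from them."""
    f = [
        math.cos(((t / num_steps) + s) / (1 + s) * math.pi / 2) ** 2
        for t in range(num_steps + 1)
    ]
    raw = [v / f[0] for v in f]  # raw[0] == 1.0 by construction

    # Step 1: betas from ratios of consecutive raw values, then clip.
    betas = []
    for t in range(1, num_steps + 1):
        beta = 1.0 - raw[t] / raw[t - 1]
        betas.append(min(max(beta, beta_min), beta_max))

    # Step 2: recompute alphas and their cumulative product from the clipped betas.
    alphas = [1.0 - b for b in betas]
    alphas_cumprod = []
    running = 1.0
    for a in alphas:
        running *= a
        alphas_cumprod.append(running)
    return betas, alphas, alphas_cumprod


betas, alphas, alphas_cumprod = cosine_schedule_clipped(250)
print(alphas_cumprod[0] < 1.0)  # the first entry is now strictly below 1
```

Because every clipped beta is strictly positive, the first cumulative product is 1 - betas[0], so 1 - alphas_cumprod[0] can no longer be zero.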

Types of changes

  • Non-breaking change (fix or new feature that would not break existing functionality).
  • Breaking change (fix or new feature that would cause existing functionality to change).
  • New tests added to cover the changes.
  • Integration tests passed locally by running ./runtests.sh -f -u --net --coverage.
  • Quick tests passed locally by running ./runtests.sh --quick --unittests --disttests.
  • In-line docstrings updated.
  • Documentation updated, tested make html command in the docs/ folder.

@slavaheroes slavaheroes force-pushed the dev branch 2 times, most recently from a82fe6a to 723c584 Compare April 23, 2025 05:19
@slavaheroes slavaheroes marked this pull request as ready for review April 23, 2025 06:37
@ericspod ericspod requested a review from virginiafdez April 23, 2025 09:58
@ericspod
Member

@virginiafdez could you please check this? Thanks!

Contributor

@virginiafdez virginiafdez left a comment

The new implementation is more in line with the paper (https://arxiv.org/abs/2102.09672). I've tried it with several different numbers of timesteps, and everything seems to be in order.

@ericspod
Member

Thanks @slavaheroes @virginiafdez, we'll put these changes through then. @KumoLiu please trigger blossom; we should rerun the Windows test since it shouldn't have failed given these changes.

@KumoLiu
Contributor

KumoLiu commented Apr 25, 2025

/build
