diff --git a/examples/weather/corrdiff/generate.py b/examples/weather/corrdiff/generate.py index e3b9f6db83..4bfc00690c 100644 --- a/examples/weather/corrdiff/generate.py +++ b/examples/weather/corrdiff/generate.py @@ -192,7 +192,9 @@ def main(cfg: DictConfig) -> None: patching=patching, ) elif cfg.sampler.type == "stochastic": - sampler_fn = partial(stochastic_sampler, patching=patching) + sampler_fn = partial( + stochastic_sampler, patching=patching, num_steps=cfg.sampler.num_steps + ) else: raise ValueError(f"Unknown sampling method {cfg.sampling.type}")