From 16f46f7c8f165c151cb6d62213468762956d83f4 Mon Sep 17 00:00:00 2001 From: Younes Abid Date: Mon, 2 Feb 2026 15:15:16 +0400 Subject: [PATCH 1/2] Fix missing num_steps parameter for stochastic sampler - Add missing num_steps=cfg.sampler.num_steps parameter to stochastic_sampler partial() call - This bug caused stochastic sampler to always use default 18 steps instead of configured num_steps - Fixes inconsistency with deterministic sampler which correctly passes num_steps parameter - Improves performance when using fewer diffusion steps (e.g., num_steps=2) --- examples/weather/corrdiff/generate.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/weather/corrdiff/generate.py b/examples/weather/corrdiff/generate.py index 23ba63d5ac..b909cee94f 100644 --- a/examples/weather/corrdiff/generate.py +++ b/examples/weather/corrdiff/generate.py @@ -192,7 +192,7 @@ def main(cfg: DictConfig) -> None: patching=patching, ) elif cfg.sampler.type == "stochastic": - sampler_fn = partial(stochastic_sampler, patching=patching) + sampler_fn = partial(stochastic_sampler, patching=patching, num_steps=cfg.sampler.num_steps) else: raise ValueError(f"Unknown sampling method {cfg.sampling.type}") From 54cd6eb91b607138018ea1092c9ccf8349c5a1fe Mon Sep 17 00:00:00 2001 From: Younes Abid Date: Mon, 16 Feb 2026 09:45:16 +0400 Subject: [PATCH 2/2] Format code with pre-commit hooks - Applied ruff formatting to generate.py - Ensures code meets project style standards --- examples/weather/corrdiff/generate.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/examples/weather/corrdiff/generate.py b/examples/weather/corrdiff/generate.py index b909cee94f..933a90e213 100644 --- a/examples/weather/corrdiff/generate.py +++ b/examples/weather/corrdiff/generate.py @@ -192,7 +192,9 @@ def main(cfg: DictConfig) -> None: patching=patching, ) elif cfg.sampler.type == "stochastic": - sampler_fn = partial(stochastic_sampler, patching=patching, num_steps=cfg.sampler.num_steps) + sampler_fn = partial( + stochastic_sampler, patching=patching, num_steps=cfg.sampler.num_steps + ) else: raise ValueError(f"Unknown sampling method {cfg.sampling.type}")