Skip to content

Commit ba1724b

Browse files
committed
move nan check to last diffusion step.
1 parent bfd776e commit ba1724b

File tree

1 file changed

+7
-8
lines changed

1 file changed

+7
-8
lines changed

sbi/samplers/score/diffuser.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -167,21 +167,20 @@ def run(
167167
# Apply predictor step
168168
samples = self.predictor(samples, t_current, t_next)
169169

170-
# Check for NaN values after predictor
171-
if torch.isnan(samples).any():
172-
raise RuntimeError(
173-
f"NaN values detected after predictor step "
174-
f"{time_step_idx}/{total_time_steps}. "
175-
f"This may indicate numerical instability in the vector field."
176-
)
177-
178170
# Apply corrector step if available
179171
if self.corrector is not None:
180172
samples = self.corrector(samples, t_next, t_current)
181173

182174
if save_intermediate:
183175
intermediate_samples.append(samples)
184176

177+
# Check for NaN values after predictor
178+
if torch.isnan(samples).any():
179+
raise RuntimeError(
180+
"NaN values detected after diffusion sampling "
181+
"This may indicate numerical instability in the vector field."
182+
)
183+
185184
if save_intermediate:
186185
return torch.cat(intermediate_samples, dim=0)
187186
else:

0 commit comments

Comments
 (0)