Skip to content

Commit f3222f9

Browse files
satwikspsjanfb
andauthored
fix: device mismatch in NPSE marginal mean/std computation (#1707)
* Fix device mismatch in NPSE marginal mean/std computation * Apply pre-commit formatting to score_estimator fix * Inline device assignment for t_tensor as suggested by reviewer Co-authored-by: Jan Teusen (né Boelts) <janfb@users.noreply.github.com> --------- Co-authored-by: Jan Teusen (né Boelts) <janfb@users.noreply.github.com>
1 parent 2c216c2 commit f3222f9

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

sbi/neural_nets/estimators/score_estimator.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -118,8 +118,10 @@ def __init__(
118118

119119
# Now that input_shape and mean_0, std_0 is set, we can compute the proper mean
120120
# and std for the "base" distribution.
121-
mean_t = self.approx_marginal_mean(torch.tensor([t_max]))
122-
std_t = self.approx_marginal_std(torch.tensor([t_max]))
121+
# Create t on the correct device to avoid CPU/GPU mismatch
122+
t_tensor = torch.as_tensor([t_max], device=self.mean_0.device)
123+
mean_t = self.approx_marginal_mean(t_tensor)
124+
std_t = self.approx_marginal_std(t_tensor)
123125
mean_t = torch.broadcast_to(mean_t, (1, *input_shape))
124126
std_t = torch.broadcast_to(std_t, (1, *input_shape))
125127

0 commit comments

Comments
 (0)