Skip to content

Commit b7c4acc

Browse files
authored
docs: warn when using append_simulations with exclude_invalid_x=True. (#1486)
Update SNRE to NRE and SNLE to NLE in append_simulations docstring. Fixes: #1427
1 parent c5c511c commit b7c4acc

File tree

2 files changed

+16
-4
lines changed

2 files changed

+16
-4
lines changed

sbi/inference/trainers/nle/nle_base.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
22
# under the Apache License Version 2.0, see <https://www.apache.org/licenses/>
33

4+
import warnings
45
from abc import ABC
56
from copy import deepcopy
67
from typing import Any, Callable, Dict, Optional, Union
@@ -94,11 +95,11 @@ def append_simulations(
9495
theta: Parameter sets.
9596
x: Simulation outputs.
9697
exclude_invalid_x: Whether invalid simulations are discarded during
97-
training. If `False`, SNLE raises an error when invalid simulations are
98+
training. If `False`, NLE raises an error when invalid simulations are
9899
found. If `True`, invalid simulations are discarded and training
99100
can proceed, but this gives systematically wrong results.
100101
from_round: Which round the data stemmed from. Round 0 means from the prior.
101-
With default settings, this is not used at all for `SNLE`. Only when
102+
With default settings, this is not used at all for `NLE`. Only when
102103
the user later on requests `.train(discard_prior_samples=True)`, we
103104
use these indices to find which training data stemmed from the prior.
104105
data_device: Where to store the data, default is on the same device where
@@ -108,6 +109,11 @@ def append_simulations(
108109
NeuralInference object (returned so that this function is chainable).
109110
"""
110111

112+
if exclude_invalid_x:
113+
warnings.warn(
114+
"NLE gives systematically wrong results when exclude_invalid_x=True.",
115+
stacklevel=2,
116+
)
111117
# pyright false positive, will be fixed with pyright 1.1.310
112118
return super().append_simulations( # type: ignore
113119
theta=theta,

sbi/inference/trainers/nre/nre_base.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
22
# under the Apache License Version 2.0, see <https://www.apache.org/licenses/>
33

4+
import warnings
45
from abc import ABC, abstractmethod
56
from copy import deepcopy
67
from typing import Any, Callable, Dict, Optional, Union
@@ -104,11 +105,11 @@ def append_simulations(
104105
theta: Parameter sets.
105106
x: Simulation outputs.
106107
exclude_invalid_x: Whether invalid simulations are discarded during
107-
training. If `False`, SNRE raises an error when invalid simulations are
108+
training. If `False`, NRE raises an error when invalid simulations are
108109
found. If `True`, invalid simulations are discarded and training
109110
can proceed, but this gives systematically wrong results.
110111
from_round: Which round the data stemmed from. Round 0 means from the prior.
111-
With default settings, this is not used at all for `SNRE`. Only when
112+
With default settings, this is not used at all for `NRE`. Only when
112113
the user later on requests `.train(discard_prior_samples=True)`, we
113114
use these indices to find which training data stemmed from the prior.
114115
data_device: Where to store the data, default is on the same device where
@@ -117,6 +118,11 @@ def append_simulations(
117118
Returns:
118119
NeuralInference object (returned so that this function is chainable).
119120
"""
121+
if exclude_invalid_x:
122+
warnings.warn(
123+
"NRE gives systematically wrong results when exclude_invalid_x=True.",
124+
stacklevel=2,
125+
)
120126

121127
return super().append_simulations( # type: ignore
122128
theta=theta,

0 commit comments

Comments
 (0)