Skip to content

Commit d82d5c2

Browse files
JasonKChowfacebook-github-bot
authored andcommitted
Add log_post_var strategy arg (facebookresearch#822)
Summary: Pull Request resolved: facebookresearch#822 A new strategy option that uses the same computation as min_post_range but instead of deciding when to finish just logs prediction variance to see the progress of the model. Reviewed By: tymmsc Differential Revision: D83012919 fbshipit-source-id: 65097f611d9563742adda3d377ab62c3e7308ec8
1 parent 9a3a6cb commit d82d5c2

2 files changed

Lines changed: 55 additions & 7 deletions

File tree

aepsych/strategy/strategy.py

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ def __init__(
5353
max_asks: int | None = None,
5454
keep_most_recent: int | None = None,
5555
min_post_range: float | None = None,
56+
log_post_var: bool = False,
5657
name: str = "",
5758
run_indefinitely: bool = False,
5859
transforms: ChainedInputTransform = ChainedInputTransform(**{}),
@@ -82,6 +83,7 @@ def __init__(
8283
as data collected from later trials. When None, the model is fitted on all data.
8384
min_post_range (float, optional): Experimental. The required difference between the posterior's minimum and maximum value in
8485
probablity space before the strategy will finish. Ignored if None (default).
86+
log_post_var (bool): Whether to log the posterior prediction variance (as it would be used for min_post_range). Defaults to False.
8587
name (str): The name of the strategy. Defaults to the empty string.
8688
run_indefinitely (bool): If true, the strategy will run indefinitely until finish() is explicitly called. Other stopping criteria will
8789
be ignored. Defaults to False.
@@ -161,8 +163,11 @@ def __init__(
161163
self.transforms = transforms
162164

163165
self.min_post_range = min_post_range
164-
if self.min_post_range is not None:
165-
assert model is not None, "min_post_range must be None if model is None!"
166+
self.log_post_var = log_post_var
167+
if self.min_post_range is not None or self.log_post_var:
168+
assert (
169+
model is not None
170+
), "posterior range cannot be evaluated if model is None!"
166171
self.eval_grid = make_scaled_sobol(
167172
lb=self.lb, ub=self.ub, size=self._n_eval_points
168173
)
@@ -429,16 +434,23 @@ def finished(self) -> bool:
429434
else:
430435
sufficient_outcomes = True
431436

432-
if self.min_post_range is not None:
437+
if self.min_post_range is not None or self.log_post_var:
433438
assert (
434439
self.model is not None
435440
), "model is None! Cannot predict without a model!"
436-
fmean, _ = self.model.predict(self.eval_grid, probability_space=True)
437-
meets_post_range = bool(
438-
((fmean.max() - fmean.min()) >= self.min_post_range).item()
439-
)
441+
fmean, fvar = self.model.predict(self.eval_grid, probability_space=True)
442+
post_range = fmean.max() - fmean.min()
443+
else:
444+
post_range = None
445+
446+
if post_range is not None:
447+
logger.info(f"Mean posterior variance = {fvar.mean().item()}")
448+
449+
if self.min_post_range:
450+
meets_post_range = bool((post_range >= self.min_post_range).item())
440451
else:
441452
meets_post_range = True
453+
442454
finished = (
443455
self._count >= self.min_asks
444456
and self.n >= self.min_total_tells

tests/test_strategy.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -227,6 +227,42 @@ def test_run_indefinitely(self):
227227
self.strat.finish()
228228
self.assertTrue(self.strat.finished)
229229

230+
def test_log_post_var(self):
231+
seed = 1
232+
torch.manual_seed(seed)
233+
np.random.seed(seed)
234+
lb = [0]
235+
ub = [1]
236+
237+
self.strat = Strategy(
238+
model=GPClassificationModel(
239+
dim=1,
240+
),
241+
generator=SobolGenerator(lb=lb, ub=ub),
242+
min_asks=10,
243+
lb=lb,
244+
ub=ub,
245+
stimuli_per_trial=1,
246+
outcome_types=["binary"],
247+
log_post_var=True,
248+
)
249+
250+
# Add some initial data
251+
for _ in range(5):
252+
points = self.strat.gen(1)
253+
response = int(np.random.rand() < 0.5)
254+
self.strat.add_data(points, torch.tensor([response]))
255+
256+
# Check that the log prints the expected message when checking finished status
257+
with self.assertLogs() as log:
258+
_ = self.strat.finished
259+
260+
# Look for the log message about mean posterior variance
261+
log_found = any("Mean posterior variance" in msg for msg in log.output)
262+
self.assertTrue(
263+
log_found, "Expected log message about mean posterior variance not found"
264+
)
265+
230266
def test_batchsobol_pairwise(self):
231267
lb = [1, 2, 3]
232268
ub = [2, 3, 4]

0 commit comments

Comments
 (0)