Skip to content

Commit 8f68733

Browse files
JasonKChow and meta-codesync[bot]
authored and committed
Remove extra normalize in pairwise probit (facebookresearch#823)
Summary: Pull Request resolved: facebookresearch#823. This diff modifies the `PairwiseProbitModel` by removing an extra normalization step in the `pairwise_probit.py` file. The normalization should now be handled by transforms. Additionally, the `test_pairwise_probit.py` file is updated to reflect this change. Reviewed By: phigua. Differential Revision: D84107038. fbshipit-source-id: 6b253c150126174a2ee1eaa7a7000c68e7348b12
1 parent d82d5c2 commit 8f68733

2 files changed

Lines changed: 36 additions & 14 deletions

File tree

aepsych/models/pairwise_probit.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
from aepsych.utils_logging import getLogger
1818
from botorch.fit import fit_gpytorch_mll
1919
from botorch.models import PairwiseGP, PairwiseLaplaceMarginalLogLikelihood
20-
from botorch.models.transforms.input import Normalize
2120
from torch.distributions import Normal
2221

2322
logger = getLogger()
@@ -93,8 +92,6 @@ def __init__(
9392

9493
self.max_fit_time = max_fit_time
9594

96-
bounds = torch.stack((self.lb, self.ub))
97-
input_transform = Normalize(d=dim, bounds=bounds)
9895
if covar_module is None:
9996
factory = DefaultMeanCovarFactory(
10097
dim=dim, stimuli_per_trial=self.stimuli_per_trial
@@ -106,7 +103,6 @@ def __init__(
106103
comparisons=None,
107104
covar_module=covar_module,
108105
jitter=1e-3,
109-
input_transform=input_transform,
110106
)
111107

112108
self.dim = dim # The Pairwise constructor sets self.dim = None.

tests/models/test_pairwise_probit.py

Lines changed: 36 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,14 @@ def test_pairwise_memorize_rescaled(self):
183183
xrescaled[:, 0, :] = xrescaled[:, 0, :] / 500 + 1
184184
xrescaled[:, 1, :] = xrescaled[:, 1, :] / 5e-6 - 1
185185
y = torch.Tensor(f_pairwise(new_novel_det, xrescaled) > 0.5).int()
186-
model = PairwiseProbitModel(lb=lb, ub=ub)
186+
lb_tensor = torch.tensor(lb)
187+
ub_tensor = torch.tensor(ub)
188+
transforms = ParameterTransforms(
189+
normalize=NormalizeScale(d=2, bounds=torch.stack([lb_tensor, ub_tensor]))
190+
)
191+
model = ParameterTransformedModel(
192+
PairwiseProbitModel, lb=lb_tensor, ub=ub_tensor, transforms=transforms
193+
)
187194
model.fit(x[:18], y[:18])
188195
with torch.no_grad():
189196
f0, _ = model.predict(x[18:, ..., 0])
@@ -341,29 +348,48 @@ def test_2d_pairwise_probit(self):
341348
ub = torch.tensor([1.0, 1.0])
342349
extra_acqf_args = {"beta": 3.84}
343350

351+
transforms = ParameterTransforms(
352+
normalize=NormalizeScale(d=2, bounds=torch.stack([lb, ub]))
353+
)
354+
sobol_gen = ParameterTransformedGenerator(
355+
generator=SobolGenerator,
356+
lb=lb,
357+
ub=ub,
358+
seed=seed,
359+
stimuli_per_trial=2,
360+
transforms=transforms,
361+
)
362+
acqf_gen = ParameterTransformedGenerator(
363+
generator=OptimizeAcqfGenerator,
364+
acqf=qUpperConfidenceBound,
365+
acqf_kwargs=extra_acqf_args,
366+
stimuli_per_trial=2,
367+
transforms=transforms,
368+
lb=lb,
369+
ub=ub,
370+
)
371+
probit_model = ParameterTransformedModel(
372+
model=PairwiseProbitModel, lb=lb, ub=ub, transforms=transforms
373+
)
344374
model_list = [
345375
Strategy(
346376
lb=lb,
347377
ub=ub,
348-
generator=SobolGenerator(lb=lb, ub=ub, seed=seed, stimuli_per_trial=2),
378+
generator=sobol_gen,
349379
min_asks=n_init,
350380
stimuli_per_trial=2,
351381
outcome_types=["binary"],
382+
transforms=transforms,
352383
),
353384
Strategy(
354385
lb=lb,
355386
ub=ub,
356-
model=PairwiseProbitModel(lb=lb, ub=ub),
357-
generator=OptimizeAcqfGenerator(
358-
acqf=qUpperConfidenceBound,
359-
acqf_kwargs=extra_acqf_args,
360-
stimuli_per_trial=2,
361-
lb=lb,
362-
ub=ub,
363-
),
387+
model=probit_model,
388+
generator=acqf_gen,
364389
min_asks=n_opt,
365390
stimuli_per_trial=2,
366391
outcome_types=["binary"],
392+
transforms=transforms,
367393
),
368394
]
369395

0 commit comments

Comments
 (0)