Skip to content

Commit c1aa7dd

Browse files
sdaultonmeta-codesync[bot]
authored andcommitted
Add lean botorch targets to avoid JAX/numpyro dependency (#3277)
Summary: Pull Request resolved: #3277 Reviewed By: saitcakmak Differential Revision: D101239631 fbshipit-source-id: 93016efa79a6c1eb326a1648703862cbf9162d1f
1 parent ca117dd commit c1aa7dd

1 file changed

Lines changed: 6 additions & 4 deletions

File tree

botorch/acquisition/objective.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -64,8 +64,10 @@ def forward(self, posterior: Posterior, X: Tensor | None = None) -> Posterior:
6464
pass # pragma: no cover
6565

6666

67-
# import DeterministicModel after PosteriorTransform to avoid circular import
68-
from botorch.models.deterministic import DeterministicModel # noqa
67+
def _is_deterministic_model(model: Model) -> bool:
68+
from botorch.models.deterministic import DeterministicModel
69+
70+
return isinstance(model, DeterministicModel)
6971

7072

7173
class ScalarizedPosteriorTransform(PosteriorTransform):
@@ -527,7 +529,7 @@ def __init__(
527529
"""
528530
super().__init__()
529531
self.pref_model = pref_model
530-
if isinstance(pref_model, DeterministicModel):
532+
if _is_deterministic_model(pref_model):
531533
if sample_shape is not None:
532534
raise ValueError("sample_shape must be None for DeterministicModel.")
533535
self.sampler = None
@@ -566,7 +568,7 @@ def forward(self, samples: Tensor, X: Tensor | None = None) -> Tensor:
566568
raise ValueError("samples should have at least 3 dimensions.")
567569

568570
posterior = self.pref_model.posterior(samples)
569-
if isinstance(self.pref_model, DeterministicModel):
571+
if _is_deterministic_model(self.pref_model):
570572
# return preference posterior mean
571573
return posterior.mean.squeeze(-1)
572574
else:

0 commit comments

Comments
 (0)