Skip to content

Commit 63346cc

Browse files
Carl Hvarfnermeta-codesync[bot]
authored andcommitted
Fix prior deserialization for priors with buffered attributes (#5167)
Summary: Pull Request resolved: #5167 The Ax JSON decoder's `botorch_component_from_json` strips the `BUFFERED_PREFIX` from state_dict keys only for `TransformedDistribution` subclasses. This misses priors like `BetaPrior` whose underlying distribution (`Beta`) uses `property` descriptors delegating to an internal `Dirichlet`, causing `_bufferize_attributes` to use the prefix. Broaden the check from `TransformedDistribution` to `(TransformedDistribution, Prior)` so all gpytorch priors with buffered attributes deserialize correctly. Reviewed By: sdaulton Differential Revision: D100341242 fbshipit-source-id: 26b2a047b48d3e1fa9a1d4585eb90cdeb1f58b62
1 parent 89c2c67 commit 63346cc

2 files changed

Lines changed: 30 additions & 3 deletions

File tree

ax/storage/json_store/decoders.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -54,9 +54,9 @@
5454
from botorch.models.transforms.input import ChainedInputTransform, InputTransform
5555
from botorch.models.transforms.outcome import ChainedOutcomeTransform, OutcomeTransform
5656
from botorch.utils.types import _DefaultType, DEFAULT
57+
from gpytorch.priors import Prior
5758
from gpytorch.priors.utils import BUFFERED_PREFIX
5859
from pyre_extensions import assert_is_instance
59-
from torch.distributions.transformed_distribution import TransformedDistribution
6060

6161
logger: logging.Logger = get_logger(__name__)
6262

@@ -369,8 +369,11 @@ def botorch_component_from_json(botorch_class: type[T], json: dict[str, Any]) ->
369369
for k, v in state_dict.items()
370370
}
371371
)
372-
if issubclass(botorch_class, TransformedDistribution):
373-
# Extract the buffered attributes for transformed priors.
372+
if issubclass(botorch_class, Prior):
373+
# Extract the buffered attributes for priors. Some priors (e.g.
374+
# BetaPrior, LogNormalPrior) store parameters with BUFFERED_PREFIX
375+
# because their underlying distribution uses @property descriptors
376+
# that cannot be deleted by _bufferize_attributes.
374377
for k in list(state_dict.keys()):
375378
if k.startswith(BUFFERED_PREFIX):
376379
state_dict[k[len(BUFFERED_PREFIX) :]] = state_dict.pop(k)

ax/storage/json_store/tests/test_json_store.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1069,6 +1069,30 @@ def test_BadStateDict(self) -> None:
10691069
del expected_json["state_dict"]["lower_bound"]
10701070
botorch_component_from_json(interval.__class__, expected_json)
10711071

1072+
def test_prior_roundtrip_serialization(self) -> None:
1073+
"""Test encode/decode roundtrip for priors with buffered attributes.
1074+
1075+
Priors whose underlying distribution uses @property descriptors
1076+
(e.g. BetaPrior via Dirichlet, LogNormalPrior via TransformedDistribution)
1077+
store state_dict keys with BUFFERED_PREFIX. The decoder must strip
1078+
the prefix to match __init__ arg names.
1079+
"""
1080+
from botorch.models.utils.priors import BetaPrior
1081+
from gpytorch.priors.torch_priors import GammaPrior, LogNormalPrior, NormalPrior
1082+
1083+
priors = [
1084+
("BetaPrior", BetaPrior(concentration1=2.5, concentration0=1.5)),
1085+
("GammaPrior", GammaPrior(concentration=2.0, rate=1.0)),
1086+
("NormalPrior", NormalPrior(loc=0.0, scale=1.0)),
1087+
("LogNormalPrior", LogNormalPrior(loc=0.0, scale=1.0)),
1088+
]
1089+
for name, prior in priors:
1090+
with self.subTest(prior=name):
1091+
encoded = botorch_component_to_dict(prior)
1092+
decoded = botorch_component_from_json(prior.__class__, encoded)
1093+
self.assertIsInstance(decoded, prior.__class__)
1094+
self.assertEqual(decoded.state_dict(), prior.state_dict())
1095+
10721096
def test_observation_features_backward_compatibility(self) -> None:
10731097
json = {
10741098
"__type": "ObservationFeatures",

0 commit comments

Comments
 (0)