Skip to content

Commit c6a1770

Browse files
ItsMrLinmeta-codesync[bot]
authored andcommitted
Centralize experiment property serialization for JSON storage (#5113)
Summary: Pull Request resolved: #5113 `update_properties_on_experiment` writes `experiment._properties` directly to the DB via raw `json.dumps`, bypassing the `experiment_to_sqa` encoder that converts non-JSON-serializable objects (like `LLMMessage` dataclasses) to plain dicts. This causes `TypeError: Object of type LLMMessage is not JSON serializable` for any experiment with LLM messages saved through the incremental property update path. This diff extracts the property preparation logic into a shared `prepare_experiment_properties_for_storage` function called from both `experiment_to_sqa` and `update_properties_on_experiment`, ensuring all save paths use consistent serialization. We considered adding a blanket `default=dataclasses.asdict` handler to the `JSONEncodedObject` TypeDecorator, but rejected it because: (1) dataclasses with non-JSON fields (e.g. `BenchmarkTrialMetadata` with `pd.DataFrame`) would produce confusing partial-conversion errors, (2) silently serializing dataclasses without paired decoders breaks round-trip fidelity, and (3) it removes the fail-fast guardrail for unexpected objects in JSON columns. Differential Revision: D98828500 fbshipit-source-id: ed458af8647d30ac018e885e00a3466daeec6104
1 parent 8ec8faa commit c6a1770

3 files changed

Lines changed: 59 additions & 14 deletions

File tree

ax/storage/sqa_store/encoder.py

Lines changed: 26 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,31 @@
8888
logger: Logger = get_logger(__name__)
8989

9090

91+
def prepare_experiment_properties_for_storage(
92+
experiment: Experiment,
93+
) -> dict[str, Any]:
94+
"""Prepare experiment properties for JSON storage by converting non-JSON-
95+
serializable objects (e.g. dataclasses) to plain dicts.
96+
97+
This is the single source of truth for experiment property serialization.
98+
All code paths that persist experiment properties to the database should
99+
use this function to ensure consistent handling.
100+
"""
101+
properties = experiment._properties.copy()
102+
if (
103+
oc := experiment.optimization_config
104+
) is not None and oc.pruning_target_parameterization is not None:
105+
properties["pruning_target_parameterization"] = arm_to_dict(
106+
oc.pruning_target_parameterization
107+
)
108+
if Keys.LLM_MESSAGES in properties:
109+
properties[Keys.LLM_MESSAGES] = [
110+
dataclasses.asdict(m) if isinstance(m, LLMMessage) else m
111+
for m in properties[Keys.LLM_MESSAGES]
112+
]
113+
return properties
114+
115+
91116
class Encoder:
92117
"""Class that contains methods for storing an Ax experiment to SQLAlchemy.
93118
@@ -230,18 +255,7 @@ def experiment_to_sqa(self, experiment: Experiment) -> SQAExperiment:
230255
]
231256
elif experiment.runner:
232257
runners.append(self.runner_to_sqa(none_throws(experiment.runner)))
233-
properties = experiment._properties.copy()
234-
if (
235-
oc := experiment.optimization_config
236-
) is not None and oc.pruning_target_parameterization is not None:
237-
properties["pruning_target_parameterization"] = arm_to_dict(
238-
oc.pruning_target_parameterization
239-
)
240-
if Keys.LLM_MESSAGES in properties:
241-
properties[Keys.LLM_MESSAGES] = [
242-
dataclasses.asdict(m) if isinstance(m, LLMMessage) else m
243-
for m in properties[Keys.LLM_MESSAGES]
244-
]
258+
properties = prepare_experiment_properties_for_storage(experiment)
245259

246260
# pyre-ignore[9]: Expected `Base` for 1st...yping.Type[Experiment]`.
247261
experiment_class: type[SQAExperiment] = self.config.class_to_sqa_class[

ax/storage/sqa_store/save.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,10 @@
2929
from ax.storage.sqa_store import validation as _validation_listeners # noqa: F401
3030
from ax.storage.sqa_store.db import session_scope, SQABase
3131
from ax.storage.sqa_store.decoder import Decoder
32-
from ax.storage.sqa_store.encoder import Encoder
32+
from ax.storage.sqa_store.encoder import (
33+
Encoder,
34+
prepare_experiment_properties_for_storage,
35+
)
3336
from ax.storage.sqa_store.sqa_classes import (
3437
SQAData,
3538
SQAExperiment,
@@ -540,10 +543,14 @@ def update_properties_on_experiment(
540543

541544
exp_id = _assert_experiment_saved(experiment_with_updated_properties)
542545

546+
properties = prepare_experiment_properties_for_storage(
547+
experiment_with_updated_properties
548+
)
549+
543550
with session_scope() as session:
544551
session.query(exp_sqa_class).filter_by(id=exp_id).update(
545552
{
546-
"properties": experiment_with_updated_properties._properties,
553+
"properties": properties,
547554
}
548555
)
549556

ax/storage/sqa_store/tests/test_sqa_store.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
from ax.core.experiment import Experiment
3535
from ax.core.experiment_status import ExperimentStatus
3636
from ax.core.generator_run import GeneratorRun
37+
from ax.core.llm_provider import LLMMessage
3738
from ax.core.metric import Metric
3839
from ax.core.multi_type_experiment import MultiTypeExperiment
3940
from ax.core.objective import MultiObjective, Objective, ScalarizedObjective
@@ -2863,6 +2864,29 @@ def test_set_immutable_search_space_and_opt_config(self) -> None:
28632864
loaded_experiment = load_experiment(experiment.name)
28642865
self.assertTrue(loaded_experiment.immutable_search_space_and_opt_config)
28652866

2867+
def test_update_properties_on_experiment_with_llm_messages(self) -> None:
2868+
"""Test that LLMMessage objects in experiment properties are correctly
2869+
serialized through the incremental update_properties_on_experiment path,
2870+
not just the full experiment_to_sqa path."""
2871+
experiment = get_experiment_with_batch_trial()
2872+
save_experiment(experiment)
2873+
2874+
messages = [
2875+
LLMMessage(role="system", content="You are helpful."),
2876+
LLMMessage(role="user", content="Hello", metadata={"key": "val"}),
2877+
]
2878+
experiment.llm_messages = messages
2879+
update_properties_on_experiment(
2880+
experiment_with_updated_properties=experiment,
2881+
)
2882+
2883+
loaded_experiment = load_experiment(experiment.name)
2884+
self.assertEqual(len(loaded_experiment.llm_messages), 2)
2885+
self.assertEqual(loaded_experiment.llm_messages[0].role, "system")
2886+
self.assertEqual(loaded_experiment.llm_messages[0].content, "You are helpful.")
2887+
self.assertEqual(loaded_experiment.llm_messages[1].role, "user")
2888+
self.assertEqual(loaded_experiment.llm_messages[1].metadata, {"key": "val"})
2889+
28662890
def test_update_properties_on_trial(self) -> None:
28672891
experiment = get_experiment_with_batch_trial()
28682892
self.assertNotIn("foo", experiment.trials[0]._properties)

0 commit comments

Comments
 (0)