Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions ax/storage/json_store/encoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

from ax.adapter.transforms.base import Transform
from ax.core import Experiment, ObservationFeatures
from ax.core.analysis_card import AnalysisCard, AnalysisCardGroup
from ax.core.arm import Arm
from ax.core.auxiliary import AuxiliaryExperiment
from ax.core.batch_trial import BatchTrial
Expand Down Expand Up @@ -78,6 +79,31 @@
from torch import Tensor


def analysis_card_to_dict(card: AnalysisCard) -> dict[str, Any]:
"""Convert Ax analysis card to a dictionary."""
return {
"__type": card.__class__.__name__,
"name": card.name,
"title": card.title,
"subtitle": card.subtitle,
"df": card.df,
"blob": card.blob,
"timestamp": card._timestamp,
}


def analysis_card_group_to_dict(group: AnalysisCardGroup) -> dict[str, Any]:
"""Convert Ax analysis card group to a dictionary."""
return {
"__type": "AnalysisCardGroup",
"name": group.name,
"title": group.title,
"subtitle": group.subtitle,
"children": group.children,
"timestamp": group._timestamp,
}


def experiment_to_dict(experiment: Experiment) -> dict[str, Any]:
"""Convert Ax experiment to a dictionary."""
return {
Expand Down
21 changes: 21 additions & 0 deletions ax/storage/json_store/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@
from ax.adapter.base import DataLoaderConfig
from ax.adapter.registry import GeneratorRegistryBase, Generators
from ax.adapter.transforms.base import Transform
from ax.analysis.graphviz.graphviz_analysis import GraphvizAnalysisCard
from ax.analysis.healthcheck.healthcheck_analysis import HealthcheckAnalysisCard
from ax.analysis.markdown.markdown_analysis import MarkdownAnalysisCard
from ax.analysis.plotly.plotly_analysis import PlotlyAnalysisCard
from ax.benchmark.benchmark_method import BenchmarkMethod
from ax.benchmark.benchmark_metric import (
BenchmarkMapMetric,
Expand All @@ -24,6 +28,7 @@
from ax.benchmark.benchmark_result import AggregatedBenchmarkResult, BenchmarkResult
from ax.benchmark.benchmark_trial_metadata import BenchmarkTrialMetadata
from ax.core import Experiment, ObservationFeatures
from ax.core.analysis_card import AnalysisCard, AnalysisCardGroup, ErrorAnalysisCard
from ax.core.arm import Arm
from ax.core.auxiliary import AuxiliaryExperiment, AuxiliaryExperimentPurpose
from ax.core.batch_trial import AbandonedArm, BatchTrial
Expand Down Expand Up @@ -115,6 +120,8 @@
transform_type_from_json,
)
from ax.storage.json_store.encoders import (
analysis_card_group_to_dict,
analysis_card_to_dict,
arm_to_dict,
auxiliary_experiment_to_dict,
backend_simulator_to_dict,
Expand Down Expand Up @@ -185,6 +192,8 @@


CORE_ENCODER_REGISTRY: dict[type[Any], Callable[[Any], dict[str, Any]]] = {
AnalysisCard: analysis_card_to_dict,
AnalysisCardGroup: analysis_card_group_to_dict,
Arm: arm_to_dict,
AuxiliaryExperiment: auxiliary_experiment_to_dict,
AndEarlyStoppingStrategy: logical_early_stopping_strategy_to_dict,
Expand All @@ -203,21 +212,25 @@
Data: data_to_dict,
ExpressionDerivedMetric: metric_to_dict,
DerivedParameter: derived_parameter_to_dict,
ErrorAnalysisCard: analysis_card_to_dict,
Experiment: experiment_to_dict,
FactorialMetric: metric_to_dict,
FixedParameter: fixed_parameter_to_dict,
GammaPrior: botorch_component_to_dict,
GraphvizAnalysisCard: analysis_card_to_dict,
GenerationStep: generation_node_to_dict,
GenerationNode: generation_node_to_dict,
GenerationStrategy: generation_strategy_to_dict,
GeneratorRun: generator_run_to_dict,
Hartmann6Metric: metric_to_dict,
HealthcheckAnalysisCard: analysis_card_to_dict,
ImprovementGlobalStoppingStrategy: improvement_global_stopping_strategy_to_dict,
Interval: botorch_component_to_dict,
IsSingleObjective: transition_criterion_to_dict,
L2NormMetric: metric_to_dict,
LogNormalPrior: botorch_component_to_dict,
MapMetric: metric_to_dict,
MarkdownAnalysisCard: analysis_card_to_dict,
MaxGenerationParallelism: pausing_criterion_to_dict,
MaxTrialsAwaitingData: pausing_criterion_to_dict,
Metric: metric_to_dict,
Expand All @@ -230,6 +243,7 @@
Normalize: botorch_component_to_dict,
FilterFeatures: botorch_component_to_dict,
PercentileEarlyStoppingStrategy: percentile_early_stopping_strategy_to_dict,
PlotlyAnalysisCard: analysis_card_to_dict,
SklearnMetric: metric_to_dict,
ChemistryMetric: metric_to_dict,
NegativeBraninMetric: metric_to_dict,
Expand Down Expand Up @@ -283,6 +297,8 @@
# splattable inputs to the resultant class, not just Types with kwarg inits.
CORE_DECODER_REGISTRY: TDecoderRegistry = {
"AbandonedArm": AbandonedArm,
"AnalysisCard": AnalysisCard,
"AnalysisCardGroup": AnalysisCardGroup,
"AndEarlyStoppingStrategy": AndEarlyStoppingStrategy,
"AutoTransitionAfterGen": AutoTransitionAfterGen,
"AuxiliaryExperiment": AuxiliaryExperiment,
Expand Down Expand Up @@ -318,19 +334,22 @@
"ExpressionDerivedMetric": ExpressionDerivedMetric,
"DerivedParameter": DerivedParameter,
"DomainType": DomainType,
"ErrorAnalysisCard": ErrorAnalysisCard,
"Experiment": Experiment,
"ExperimentStatus": ExperimentStatus,
"FactorialMetric": FactorialMetric,
"FilterFeatures": FilterFeatures,
"FixedParameter": fixed_parameter_from_json,
"GammaPrior": GammaPrior,
"GraphvizAnalysisCard": GraphvizAnalysisCard,
"GenerationNode": GenerationNode,
"GenerationStrategy": GenerationStrategy,
"GenerationStep": GenerationStep,
"GeneratorRun": GeneratorRun,
"Generators": Generators,
"GeneratorSpec": GeneratorSpec,
"Hartmann6Metric": Hartmann6Metric,
"HealthcheckAnalysisCard": HealthcheckAnalysisCard,
"HierarchicalSearchSpace": HierarchicalSearchSpace,
"ImprovementGlobalStoppingStrategy": ImprovementGlobalStoppingStrategy,
"InputConstructorPurpose": InputConstructorPurpose,
Expand All @@ -345,6 +364,7 @@
"LogNormalPrior": LogNormalPrior,
"MapData": Data,
"MapMetric": MapMetric,
"MarkdownAnalysisCard": MarkdownAnalysisCard,
"MaxTrials": MinTrials,
"MaxGenerationParallelism": MaxGenerationParallelism,
"MaxTrialsAwaitingData": MaxTrialsAwaitingData,
Expand Down Expand Up @@ -382,6 +402,7 @@
"PurePosixPath": pathlib_from_json,
"PureWindowsPath": pathlib_from_json,
"PercentileEarlyStoppingStrategy": percentile_early_stopping_strategy_from_json,
"PlotlyAnalysisCard": PlotlyAnalysisCard,
"RangeParameter": RangeParameter,
"ReductionCriterion": ReductionCriterion,
"Round": Round,
Expand Down
89 changes: 89 additions & 0 deletions ax/storage/json_store/tests/test_json_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,10 @@
from ax.adapter.transforms.base import Transform
from ax.adapter.transforms.log import Log
from ax.adapter.transforms.one_hot import OneHot
from ax.analysis.graphviz.graphviz_analysis import GraphvizAnalysisCard
from ax.analysis.healthcheck.healthcheck_analysis import HealthcheckAnalysisCard
from ax.analysis.markdown.markdown_analysis import MarkdownAnalysisCard
from ax.analysis.plotly.plotly_analysis import PlotlyAnalysisCard
from ax.benchmark.methods.sobol import get_sobol_benchmark_method
from ax.benchmark.testing.benchmark_stubs import (
get_aggregated_benchmark_result,
Expand All @@ -33,6 +37,7 @@
get_benchmark_result,
get_benchmark_time_varying_metric,
)
from ax.core.analysis_card import AnalysisCard, AnalysisCardGroup, ErrorAnalysisCard
from ax.core.auxiliary import AuxiliaryExperimentPurpose
from ax.core.data import Data
from ax.core.generator_run import GeneratorRun
Expand Down Expand Up @@ -406,6 +411,90 @@
("ThresholdEarlyStoppingStrategy", get_threshold_early_stopping_strategy),
("Trial", get_trial),
("WinsorizationConfig", get_winsorization_config),
(
"AnalysisCard",
lambda: AnalysisCard(
name="TestAnalysis",
title="Test",
subtitle="subtitle",
df=pd.DataFrame({"a": [1, 2]}),
blob="blob_str",
),
),
(
"ErrorAnalysisCard",
lambda: ErrorAnalysisCard(
name="TestError",
title="Error",
subtitle="err subtitle",
df=pd.DataFrame(),
blob="error details",
),
),
(
"PlotlyAnalysisCard",
lambda: PlotlyAnalysisCard(
name="TestPlotly",
title="Plot",
subtitle="plot subtitle",
df=pd.DataFrame({"x": [1]}),
blob="{}",
),
),
(
"MarkdownAnalysisCard",
lambda: MarkdownAnalysisCard(
name="TestMd",
title="MD",
subtitle="md subtitle",
df=pd.DataFrame(),
blob="# Hello",
),
),
(
"HealthcheckAnalysisCard",
lambda: HealthcheckAnalysisCard(
name="TestHC",
title="HC",
subtitle="hc subtitle",
df=pd.DataFrame(),
blob='{"status": 0}',
),
),
(
"GraphvizAnalysisCard",
lambda: GraphvizAnalysisCard(
name="TestGV",
title="GV",
subtitle="gv subtitle",
df=pd.DataFrame(),
blob="digraph {}",
),
),
(
"AnalysisCardGroup",
lambda: AnalysisCardGroup(
name="TestGroup",
title="Group",
subtitle="group subtitle",
children=[
AnalysisCard(
name="Child",
title="C1",
subtitle="s1",
df=pd.DataFrame({"a": [1]}),
blob="b1",
),
MarkdownAnalysisCard(
name="Child2",
title="C2",
subtitle="s2",
df=pd.DataFrame(),
blob="# md",
),
],
),
),
]


Expand Down
Loading