diff --git a/ax/storage/json_store/encoders.py b/ax/storage/json_store/encoders.py index 0d49f1abe20..fe62efec79b 100644 --- a/ax/storage/json_store/encoders.py +++ b/ax/storage/json_store/encoders.py @@ -12,6 +12,7 @@ from ax.adapter.transforms.base import Transform from ax.core import Experiment, ObservationFeatures +from ax.core.analysis_card import AnalysisCard, AnalysisCardGroup from ax.core.arm import Arm from ax.core.auxiliary import AuxiliaryExperiment from ax.core.batch_trial import BatchTrial @@ -78,6 +79,31 @@ from torch import Tensor +def analysis_card_to_dict(card: AnalysisCard) -> dict[str, Any]: + """Convert Ax analysis card to a dictionary.""" + return { + "__type": card.__class__.__name__, + "name": card.name, + "title": card.title, + "subtitle": card.subtitle, + "df": card.df, + "blob": card.blob, + "timestamp": card._timestamp, + } + + +def analysis_card_group_to_dict(group: AnalysisCardGroup) -> dict[str, Any]: + """Convert Ax analysis card group to a dictionary.""" + return { + "__type": "AnalysisCardGroup", + "name": group.name, + "title": group.title, + "subtitle": group.subtitle, + "children": group.children, + "timestamp": group._timestamp, + } + + def experiment_to_dict(experiment: Experiment) -> dict[str, Any]: """Convert Ax experiment to a dictionary.""" return { diff --git a/ax/storage/json_store/registry.py b/ax/storage/json_store/registry.py index 3a7f4fcec97..0036643e80a 100644 --- a/ax/storage/json_store/registry.py +++ b/ax/storage/json_store/registry.py @@ -14,6 +14,10 @@ from ax.adapter.base import DataLoaderConfig from ax.adapter.registry import GeneratorRegistryBase, Generators from ax.adapter.transforms.base import Transform +from ax.analysis.graphviz.graphviz_analysis import GraphvizAnalysisCard +from ax.analysis.healthcheck.healthcheck_analysis import HealthcheckAnalysisCard +from ax.analysis.markdown.markdown_analysis import MarkdownAnalysisCard +from ax.analysis.plotly.plotly_analysis import PlotlyAnalysisCard from ax.benchmark.benchmark_method import BenchmarkMethod from ax.benchmark.benchmark_metric import ( BenchmarkMapMetric, @@ -24,6 +28,7 @@ from ax.benchmark.benchmark_result import AggregatedBenchmarkResult, BenchmarkResult from ax.benchmark.benchmark_trial_metadata import BenchmarkTrialMetadata from ax.core import Experiment, ObservationFeatures +from ax.core.analysis_card import AnalysisCard, AnalysisCardGroup, ErrorAnalysisCard from ax.core.arm import Arm from ax.core.auxiliary import AuxiliaryExperiment, AuxiliaryExperimentPurpose from ax.core.batch_trial import AbandonedArm, BatchTrial @@ -115,6 +120,8 @@ transform_type_from_json, ) from ax.storage.json_store.encoders import ( + analysis_card_group_to_dict, + analysis_card_to_dict, arm_to_dict, auxiliary_experiment_to_dict, backend_simulator_to_dict, @@ -185,6 +192,8 @@ CORE_ENCODER_REGISTRY: dict[type[Any], Callable[[Any], dict[str, Any]]] = { + AnalysisCard: analysis_card_to_dict, + AnalysisCardGroup: analysis_card_group_to_dict, Arm: arm_to_dict, AuxiliaryExperiment: auxiliary_experiment_to_dict, AndEarlyStoppingStrategy: logical_early_stopping_strategy_to_dict, @@ -203,21 +212,25 @@ Data: data_to_dict, ExpressionDerivedMetric: metric_to_dict, DerivedParameter: derived_parameter_to_dict, + ErrorAnalysisCard: analysis_card_to_dict, Experiment: experiment_to_dict, FactorialMetric: metric_to_dict, FixedParameter: fixed_parameter_to_dict, GammaPrior: botorch_component_to_dict, + GraphvizAnalysisCard: analysis_card_to_dict, GenerationStep: generation_node_to_dict, GenerationNode: generation_node_to_dict, GenerationStrategy: generation_strategy_to_dict, GeneratorRun: generator_run_to_dict, Hartmann6Metric: metric_to_dict, + HealthcheckAnalysisCard: analysis_card_to_dict, ImprovementGlobalStoppingStrategy: improvement_global_stopping_strategy_to_dict, Interval: botorch_component_to_dict, IsSingleObjective: transition_criterion_to_dict, L2NormMetric: metric_to_dict, LogNormalPrior: botorch_component_to_dict, MapMetric: metric_to_dict, + MarkdownAnalysisCard: analysis_card_to_dict, MaxGenerationParallelism: pausing_criterion_to_dict, MaxTrialsAwaitingData: pausing_criterion_to_dict, Metric: metric_to_dict, @@ -230,6 +243,7 @@ Normalize: botorch_component_to_dict, FilterFeatures: botorch_component_to_dict, PercentileEarlyStoppingStrategy: percentile_early_stopping_strategy_to_dict, + PlotlyAnalysisCard: analysis_card_to_dict, SklearnMetric: metric_to_dict, ChemistryMetric: metric_to_dict, NegativeBraninMetric: metric_to_dict, @@ -283,6 +297,8 @@ # splattable inputs to the resultant class, not just Types with kwarg inits. CORE_DECODER_REGISTRY: TDecoderRegistry = { "AbandonedArm": AbandonedArm, + "AnalysisCard": AnalysisCard, + "AnalysisCardGroup": AnalysisCardGroup, "AndEarlyStoppingStrategy": AndEarlyStoppingStrategy, "AutoTransitionAfterGen": AutoTransitionAfterGen, "AuxiliaryExperiment": AuxiliaryExperiment, @@ -318,12 +334,14 @@ "ExpressionDerivedMetric": ExpressionDerivedMetric, "DerivedParameter": DerivedParameter, "DomainType": DomainType, + "ErrorAnalysisCard": ErrorAnalysisCard, "Experiment": Experiment, "ExperimentStatus": ExperimentStatus, "FactorialMetric": FactorialMetric, "FilterFeatures": FilterFeatures, "FixedParameter": fixed_parameter_from_json, "GammaPrior": GammaPrior, + "GraphvizAnalysisCard": GraphvizAnalysisCard, "GenerationNode": GenerationNode, "GenerationStrategy": GenerationStrategy, "GenerationStep": GenerationStep, @@ -331,6 +349,7 @@ "Generators": Generators, "GeneratorSpec": GeneratorSpec, "Hartmann6Metric": Hartmann6Metric, + "HealthcheckAnalysisCard": HealthcheckAnalysisCard, "HierarchicalSearchSpace": HierarchicalSearchSpace, "ImprovementGlobalStoppingStrategy": ImprovementGlobalStoppingStrategy, "InputConstructorPurpose": InputConstructorPurpose, @@ -345,6 +364,7 @@ "LogNormalPrior": LogNormalPrior, "MapData": Data, "MapMetric": MapMetric, + "MarkdownAnalysisCard": MarkdownAnalysisCard, "MaxTrials": MinTrials, "MaxGenerationParallelism": MaxGenerationParallelism, "MaxTrialsAwaitingData": MaxTrialsAwaitingData, @@ -382,6 +402,7 @@ "PurePosixPath": pathlib_from_json, "PureWindowsPath": pathlib_from_json, "PercentileEarlyStoppingStrategy": percentile_early_stopping_strategy_from_json, + "PlotlyAnalysisCard": PlotlyAnalysisCard, "RangeParameter": RangeParameter, "ReductionCriterion": ReductionCriterion, "Round": Round, diff --git a/ax/storage/json_store/tests/test_json_store.py b/ax/storage/json_store/tests/test_json_store.py index 80b120825f8..38993cba4b4 100644 --- a/ax/storage/json_store/tests/test_json_store.py +++ b/ax/storage/json_store/tests/test_json_store.py @@ -24,6 +24,10 @@ from ax.adapter.transforms.base import Transform from ax.adapter.transforms.log import Log from ax.adapter.transforms.one_hot import OneHot +from ax.analysis.graphviz.graphviz_analysis import GraphvizAnalysisCard +from ax.analysis.healthcheck.healthcheck_analysis import HealthcheckAnalysisCard +from ax.analysis.markdown.markdown_analysis import MarkdownAnalysisCard +from ax.analysis.plotly.plotly_analysis import PlotlyAnalysisCard from ax.benchmark.methods.sobol import get_sobol_benchmark_method from ax.benchmark.testing.benchmark_stubs import ( get_aggregated_benchmark_result, @@ -33,6 +37,7 @@ get_benchmark_result, get_benchmark_time_varying_metric, ) +from ax.core.analysis_card import AnalysisCard, AnalysisCardGroup, ErrorAnalysisCard from ax.core.auxiliary import AuxiliaryExperimentPurpose from ax.core.data import Data from ax.core.generator_run import GeneratorRun @@ -406,6 +411,90 @@ ("ThresholdEarlyStoppingStrategy", get_threshold_early_stopping_strategy), ("Trial", get_trial), ("WinsorizationConfig", get_winsorization_config), + ( + "AnalysisCard", + lambda: AnalysisCard( + name="TestAnalysis", + title="Test", + subtitle="subtitle", + df=pd.DataFrame({"a": [1, 2]}), + blob="blob_str", + ), + ), + ( + "ErrorAnalysisCard", + lambda: ErrorAnalysisCard( + name="TestError", + title="Error", + subtitle="err subtitle", + df=pd.DataFrame(), + blob="error details", + ), + ), + ( + "PlotlyAnalysisCard", + lambda: PlotlyAnalysisCard( + name="TestPlotly", + title="Plot", + subtitle="plot subtitle", + df=pd.DataFrame({"x": [1]}), + blob="{}", + ), + ), + ( + "MarkdownAnalysisCard", + lambda: MarkdownAnalysisCard( + name="TestMd", + title="MD", + subtitle="md subtitle", + df=pd.DataFrame(), + blob="# Hello", + ), + ), + ( + "HealthcheckAnalysisCard", + lambda: HealthcheckAnalysisCard( + name="TestHC", + title="HC", + subtitle="hc subtitle", + df=pd.DataFrame(), + blob='{"status": 0}', + ), + ), + ( + "GraphvizAnalysisCard", + lambda: GraphvizAnalysisCard( + name="TestGV", + title="GV", + subtitle="gv subtitle", + df=pd.DataFrame(), + blob="digraph {}", + ), + ), + ( + "AnalysisCardGroup", + lambda: AnalysisCardGroup( + name="TestGroup", + title="Group", + subtitle="group subtitle", + children=[ + AnalysisCard( + name="Child", + title="C1", + subtitle="s1", + df=pd.DataFrame({"a": [1]}), + blob="b1", + ), + MarkdownAnalysisCard( + name="Child2", + title="C2", + subtitle="s2", + df=pd.DataFrame(), + blob="# md", + ), + ], + ), + ), ]