1414from ax .analysis .plotly .color_constants import BOTORCH_COLOR_SCALE
1515from ax .analysis .plotly .plotly_analysis import create_plotly_analysis_card
1616from ax .analysis .plotly .utils import (
17+ generator_run_key_to_color ,
1718 get_arm_tooltip ,
1819 get_trial_statuses_with_fallback ,
1920 get_trial_trace_name ,
3839from ax .core .arm import Arm
3940from ax .core .data import sort_by_trial_index_and_arm_name
4041from ax .core .experiment import Experiment
42+ from ax .core .generator_run import GeneratorRun
4143from ax .core .trial_status import TrialStatus
4244from ax .generation_strategy .generation_strategy import GenerationStrategy
4345from plotly import graph_objects as go
@@ -76,6 +78,7 @@ def __init__(
7678 trial_index : int | None = None ,
7779 trial_statuses : Sequence [TrialStatus ] | None = None ,
7880 additional_arms : Sequence [Arm ] | None = None ,
81+ generator_runs : Mapping [str , GeneratorRun ] | None = None ,
7982 label : str | None = None ,
8083 ) -> None :
8184 """
@@ -95,6 +98,10 @@ def __init__(
9598 additional_arms: If present, include these arms in the plot in addition to
9699 the arms in the experiment. These arms will be marked as belonging to a
97100 trial with index -1.
101+ generator_runs: If present, a mapping from name to GeneratorRun. Each
102+ GeneratorRun's arms will be plotted as a separate group with distinct
103+ colors and legend entries. Unnamed arms will be labeled as
104+ ``{key}_0``, ``{key}_1``, etc.
98105 label: A label to use in the plot in place of the metric name.
99106 """
100107
@@ -108,6 +115,7 @@ def __init__(
108115 )
109116 )
110117 self .additional_arms = additional_arms
118+ self .generator_runs = generator_runs
111119 self .label = label
112120
113121 @override
@@ -186,6 +194,7 @@ def compute(
186194 trial_index = self .trial_index ,
187195 trial_statuses = self .trial_statuses ,
188196 additional_arms = self .additional_arms ,
197+ generator_runs = self .generator_runs ,
189198 relativize = self .relativize ,
190199 )
191200
@@ -255,6 +264,7 @@ def compute_arm_effects_adhoc(
255264 trial_index : int | None = None ,
256265 trial_statuses : Sequence [TrialStatus ] | None = None ,
257266 additional_arms : Sequence [Arm ] | None = None ,
267+ generator_runs : Mapping [str , GeneratorRun ] | None = None ,
258268 labels : Mapping [str , str ] | None = None ,
259269) -> AnalysisCardGroup :
260270 """
@@ -299,6 +309,7 @@ def compute_arm_effects_adhoc(
299309 trial_index = trial_index ,
300310 trial_statuses = trial_statuses ,
301311 additional_arms = additional_arms ,
312+ generator_runs = generator_runs ,
302313 label = labels .get (metric_name ) if labels is not None else None ,
303314 ).compute_or_error_card (
304315 experiment = experiment ,
@@ -314,6 +325,58 @@ def compute_arm_effects_adhoc(
314325 )
315326
316327
328+ def _build_scatter (
329+ trial_df : pd .DataFrame ,
330+ metric_name : str ,
331+ is_relative : bool ,
332+ status_quo_arm_name : str | None ,
333+ color : str ,
334+ ci_color : str ,
335+ trace_name : str ,
336+ showlegend : bool ,
337+ legendgroup : str | None ,
338+ ) -> tuple [go .Scatter , list [str ], list [str ]] | None :
339+ """Build a scatter trace for a group of arms.
340+
341+ Returns (scatter, arm_order_entries, arm_label_entries), or None if no
342+ valid data points exist.
343+ """
344+ xy_df = trial_df [~ trial_df [f"{ metric_name } _mean" ].isna ()]
345+ if xy_df .empty :
346+ return None
347+ if is_relative and status_quo_arm_name is not None :
348+ xy_df = xy_df [xy_df ["arm_name" ] != status_quo_arm_name ]
349+ if xy_df .empty :
350+ return None
351+
352+ if not trial_df [f"{ metric_name } _sem" ].isna ().all ():
353+ error_y = {
354+ "type" : "data" ,
355+ "array" : Z_SCORE_95_CI * xy_df [f"{ metric_name } _sem" ],
356+ "color" : ci_color ,
357+ }
358+ else :
359+ error_y = None
360+
361+ text = xy_df .apply (
362+ lambda row : get_arm_tooltip (row = row , metric_names = [metric_name ]), axis = 1
363+ )
364+
365+ scatter = go .Scatter (
366+ x = xy_df ["x_key_order" ],
367+ y = xy_df [f"{ metric_name } _mean" ],
368+ error_y = error_y ,
369+ mode = "markers" ,
370+ marker = {"color" : color },
371+ name = trace_name ,
372+ showlegend = showlegend ,
373+ hoverinfo = "text" ,
374+ text = text ,
375+ legendgroup = legendgroup ,
376+ )
377+ return scatter , xy_df ["x_key_order" ].to_list (), xy_df ["arm_name" ].to_list ()
378+
379+
317380def _prepare_figure (
318381 df : pd .DataFrame ,
319382 metric_name : str ,
@@ -350,76 +413,94 @@ def _prepare_figure(
350413 num_non_candidate_trials = 0
351414 candidate_trial_marker = None
352415
416+ # --- Trial loop ---
353417 for trial_index in trial_indices :
354- trial_df = df [df [ "trial_index" ] == trial_index ]
355- xy_df = trial_df [ ~ trial_df [ f" { metric_name } _mean " ].isna ()]
356- # Skip trials with no valid data points as they will not end up in the plot
357- if xy_df .empty :
418+ trial_df = df [
419+ ( df [ "trial_index" ] == trial_index ) & ( df [ "generator_run_key " ].isna ())
420+ ]
421+ if trial_df .empty :
358422 continue
359- if is_relative and status_quo_arm_name is not None :
360- # Exclude status quo arms from relativized plots, since arms are relative
361- # with respect to the status quo.
362- xy_df = xy_df [xy_df ["arm_name" ] != status_quo_arm_name ]
363-
364- arm_order = arm_order + xy_df ["x_key_order" ].to_list ()
365- arm_label = arm_label + xy_df ["arm_name" ].to_list ()
366- if not trial_df [f"{ metric_name } _sem" ].isna ().all ():
367- error_y = {
368- "type" : "data" ,
369- "array" : Z_SCORE_95_CI * xy_df [f"{ metric_name } _sem" ],
370- "color" : trial_index_to_color (
371- trial_df = trial_df ,
372- trials_list = trials_list ,
373- trial_index = trial_index ,
374- transparent = True ,
375- ),
376- }
377- else :
378- error_y = None
379-
380- marker = {
381- "color" : trial_index_to_color (
382- trial_df = trial_df ,
383- trials_list = trials_list ,
384- trial_index = trial_index ,
385- transparent = False ,
386- ),
387- }
388-
389- if trial_df ["trial_status" ].iloc [0 ] == TrialStatus .CANDIDATE .name :
423+ color = trial_index_to_color (
424+ trial_df = trial_df ,
425+ trials_list = trials_list ,
426+ trial_index = trial_index ,
427+ transparent = False ,
428+ )
429+ ci_color = trial_index_to_color (
430+ trial_df = trial_df ,
431+ trials_list = trials_list ,
432+ trial_index = trial_index ,
433+ transparent = True ,
434+ )
435+ is_candidate = (
436+ not trial_df .empty
437+ and trial_df ["trial_status" ].iloc [0 ] == TrialStatus .CANDIDATE .name
438+ )
439+ result = _build_scatter (
440+ trial_df = trial_df ,
441+ metric_name = metric_name ,
442+ is_relative = is_relative ,
443+ status_quo_arm_name = status_quo_arm_name ,
444+ color = color ,
445+ ci_color = ci_color ,
446+ trace_name = get_trial_trace_name (trial_index = trial_index ),
447+ showlegend = False , # Will be set after determining use_colorscale
448+ legendgroup = "candidate_trials" if is_candidate else None ,
449+ )
450+ if result is None :
451+ continue
452+ scatter , order , labels = result
453+ scatters .append (scatter )
454+ scatter_trial_indices .append (trial_index )
455+ arm_order += order
456+ arm_label += labels
457+ if is_candidate :
390458 num_candidate_trials += 1
391- candidate_trial_marker = marker
459+ candidate_trial_marker = { "color" : color }
392460 else :
393461 num_non_candidate_trials += 1
394462
395- text = xy_df .apply (
396- lambda row : get_arm_tooltip (row = row , metric_names = [metric_name ]), axis = 1
463+ # --- Generator run loop ---
464+ unique_gr_keys = df ["generator_run_key" ].dropna ().unique ().tolist ()
465+ generator_run_scatters : list [go .Scatter ] = []
466+ for gr_key in unique_gr_keys :
467+ gr_df = df [df ["generator_run_key" ] == gr_key ]
468+ color = generator_run_key_to_color (
469+ generator_run_key = gr_key ,
470+ all_generator_run_keys = unique_gr_keys ,
471+ transparent = False ,
397472 )
398-
399- scatters .append (
400- go .Scatter (
401- x = xy_df ["x_key_order" ],
402- y = xy_df [f"{ metric_name } _mean" ],
403- error_y = error_y ,
404- mode = "markers" ,
405- marker = marker ,
406- name = get_trial_trace_name (trial_index = trial_index ),
407- showlegend = False , # Will be set after determining use_colorscale
408- hoverinfo = "text" ,
409- text = text ,
410- legendgroup = "candidate_trials"
411- if trial_df ["trial_status" ].iloc [0 ] == TrialStatus .CANDIDATE .name
412- else None ,
413- )
473+ ci_color = generator_run_key_to_color (
474+ generator_run_key = gr_key ,
475+ all_generator_run_keys = unique_gr_keys ,
476+ transparent = True ,
414477 )
415- scatter_trial_indices .append (trial_index )
478+ result = _build_scatter (
479+ trial_df = gr_df ,
480+ metric_name = metric_name ,
481+ is_relative = is_relative ,
482+ status_quo_arm_name = status_quo_arm_name ,
483+ color = color ,
484+ ci_color = ci_color ,
485+ trace_name = gr_key ,
486+ showlegend = True ,
487+ legendgroup = None ,
488+ )
489+ if result is None :
490+ continue
491+ scatter , order , labels = result
492+ generator_run_scatters .append (scatter )
493+ arm_order += order
494+ arm_label += labels
416495
417496 # Determine use_colorscale based on actual included trials
418497 use_colorscale = num_non_candidate_trials > 10
419498
420499 # Update markers and legend settings based on use_colorscale
421500 for scatter , trial_index in zip (scatters , scatter_trial_indices ):
422- trial_df = df [df ["trial_index" ] == trial_index ]
501+ trial_df = df [
502+ (df ["trial_index" ] == trial_index ) & (df ["generator_run_key" ].isna ())
503+ ]
423504
424505 if use_colorscale :
425506 # Add colorscale settings to marker
@@ -445,6 +526,9 @@ def _prepare_figure(
445526 trial_df ["trial_status" ].iloc [0 ] != TrialStatus .CANDIDATE .name
446527 )
447528
529+ # Append generator run scatters (not subject to colorscale)
530+ scatters .extend (generator_run_scatters )
531+
448532 # get the max length of x-ticker (arm name) to set the xaxis label and
449533 # legend position
450534 # This assumes the x-tickers are rotated 90 degrees (vertical) so legend
0 commit comments