@@ -100,7 +100,7 @@ def __init__(
100100 default_data_type : Any = None ,
101101 auxiliary_experiments_by_purpose : None
102102 | (dict [AuxiliaryExperimentPurpose , list [AuxiliaryExperiment ]]) = None ,
103- default_trial_type : str | None = None ,
103+ default_trial_type : str = Keys . DEFAULT_TRIAL_TYPE . value ,
104104 ) -> None :
105105 """Inits Experiment.
106106
@@ -123,6 +123,8 @@ def __init__(
123123 default_data_type: Deprecated and ignored.
124124 auxiliary_experiments_by_purpose: Dictionary of auxiliary experiments
125125 for different purposes (e.g., transfer learning).
126+ default_trial_type: Default trial type for trials on this experiment.
127+ Defaults to Keys.DEFAULT_TRIAL_TYPE.
126128 """
127129 if default_data_type is not None :
128130 warnings .warn (
@@ -150,10 +152,16 @@ def __init__(
150152 self ._properties : dict [str , Any ] = properties or {}
151153
152154 # Initialize trial type to runner mapping
153- self ._default_trial_type = default_trial_type
154- self ._trial_type_to_runner : dict [str | None , Runner | None ] = {
155- default_trial_type : runner
155+ self ._default_trial_type : str = (
156+ default_trial_type or Keys .DEFAULT_TRIAL_TYPE .value
157+ )
158+ self ._trial_type_to_runner : dict [str , Runner | None ] = {
159+ self ._default_trial_type : runner
156160 }
161+
162+ # Maps metric names to their trial types. Every metric must have an entry.
163+ self ._metric_to_trial_type : dict [str , str ] = {}
164+
157165 # Used to keep track of whether any trials on the experiment
158166 # specify a TTL. Since trials need to be checked for their TTL's
159167 # expiration often, having this attribute helps avoid unnecessary
@@ -413,16 +421,46 @@ def runner(self) -> Runner | None:
413421 def runner (self , runner : Runner | None ) -> None :
414422 """Set the default runner and update trial type mapping."""
415423 self ._runner = runner
416- if runner is not None :
417- self ._trial_type_to_runner [self ._default_trial_type ] = runner
418- else :
419- self ._trial_type_to_runner = {None : None }
424+ self ._trial_type_to_runner [self ._default_trial_type ] = runner
420425
421426 @runner .deleter
422427 def runner (self ) -> None :
423428 """Delete the runner."""
424429 self ._runner = None
425- self ._trial_type_to_runner = {None : None }
430+ self ._trial_type_to_runner [self ._default_trial_type ] = None
431+
432+ def add_trial_type (self , trial_type : str , runner : Runner ) -> "Experiment" :
433+ """Add a new trial_type to be supported by this experiment.
434+
435+ Args:
436+ trial_type: The new trial_type to be added.
437+ runner: The default runner for trials of this type.
438+
439+ Returns:
440+ The experiment with the new trial type added.
441+ """
442+ if self .supports_trial_type (trial_type ):
443+ raise ValueError (f"Experiment already contains trial_type `{ trial_type } `" )
444+
445+ self ._trial_type_to_runner [trial_type ] = runner
446+ return self
447+
448+ def update_runner (self , trial_type : str , runner : Runner ) -> "Experiment" :
449+ """Update the default runner for an existing trial_type.
450+
451+ Args:
452+ trial_type: The trial_type to update.
453+ runner: The new runner for trials of this type.
454+
455+ Returns:
456+ The experiment with the updated runner.
457+ """
458+ if not self .supports_trial_type (trial_type ):
459+ raise ValueError (f"Experiment does not contain trial_type `{ trial_type } `" )
460+
461+ self ._trial_type_to_runner [trial_type ] = runner
462+ self ._runner = runner
463+ return self
426464
427465 @property
428466 def parameters (self ) -> dict [str , Parameter ]:
@@ -489,13 +527,25 @@ def optimization_config(self, optimization_config: OptimizationConfig) -> None:
489527 f"`{ Keys .IMMUTABLE_SEARCH_SPACE_AND_OPT_CONF .value } ` "
490528 "property that is set to `True` on this experiment."
491529 )
530+
531+ # Remove old OC metrics from trial type mapping
532+ prev_optimization_config = self ._optimization_config
533+ if prev_optimization_config is not None :
534+ for metric_name in prev_optimization_config .metrics .keys ():
535+ self ._metric_to_trial_type .pop (metric_name , None )
536+
492537 for metric_name in optimization_config .metrics .keys ():
493538 if metric_name in self ._tracking_metrics :
494539 self .remove_tracking_metric (metric_name )
540+
495541 # add metrics from the previous optimization config that are not in the new
496542 # optimization config as tracking metrics
497- prev_optimization_config = self ._optimization_config
498543 self ._optimization_config = optimization_config
544+
545+ # Map new OC metrics to default trial type
546+ for metric_name in optimization_config .metrics .keys ():
547+ self ._metric_to_trial_type [metric_name ] = self ._default_trial_type
548+
499549 if prev_optimization_config is not None :
500550 metrics_to_track = (
501551 set (prev_optimization_config .metrics .keys ())
@@ -505,6 +555,16 @@ def optimization_config(self, optimization_config: OptimizationConfig) -> None:
505555 for metric_name in metrics_to_track :
506556 self .add_tracking_metric (prev_optimization_config .metrics [metric_name ])
507557
558+ # Clean up any stale entries in _metric_to_trial_type that don't correspond
559+ # to actual metrics (can happen when same optimization_config object is
560+ # mutated and reassigned).
561+ current_metric_names = set (self .metrics .keys ())
562+ stale_metric_names = (
563+ set (self ._metric_to_trial_type .keys ()) - current_metric_names
564+ )
565+ for metric_name in stale_metric_names :
566+ self ._metric_to_trial_type .pop (metric_name , None )
567+
508568 @property
509569 def is_moo_problem (self ) -> bool :
510570 """Whether the experiment's optimization config contains multiple objectives."""
@@ -553,12 +613,25 @@ def immutable_search_space_and_opt_config(self) -> bool:
553613 def tracking_metrics (self ) -> list [Metric ]:
554614 return list (self ._tracking_metrics .values ())
555615
556- def add_tracking_metric (self , metric : Metric ) -> Self :
616+ def add_tracking_metric (
617+ self ,
618+ metric : Metric ,
619+ trial_type : str | None = None ,
620+ ) -> Self :
557621 """Add a new metric to the experiment.
558622
559623 Args:
560624 metric: Metric to be added.
625+ trial_type: The trial type for which this metric is used. If not
626+ provided, defaults to the experiment's default trial type.
561627 """
628+ effective_trial_type = (
629+ trial_type if trial_type is not None else self ._default_trial_type
630+ )
631+
632+ if not self .supports_trial_type (effective_trial_type ):
633+ raise ValueError (f"`{ effective_trial_type } ` is not a supported trial type." )
634+
562635 if metric .name in self ._tracking_metrics :
563636 raise ValueError (
564637 f"Metric `{ metric .name } ` already defined on experiment. "
@@ -574,33 +647,73 @@ def add_tracking_metric(self, metric: Metric) -> Self:
574647 )
575648
576649 self ._tracking_metrics [metric .name ] = metric
650+ self ._metric_to_trial_type [metric .name ] = effective_trial_type
577651 return self
578652
579- def add_tracking_metrics (self , metrics : list [Metric ]) -> Self :
653+ def add_tracking_metrics (
654+ self ,
655+ metrics : list [Metric ],
656+ metrics_to_trial_types : dict [str , str ] | None = None ,
657+ ) -> Self :
580658 """Add a list of new metrics to the experiment.
581659
582660 If any of the metrics are already defined on the experiment,
583661 we raise an error and don't add any of them to the experiment
584662
585663 Args:
586664 metrics: Metrics to be added.
665+ metrics_to_trial_types: Optional mapping from metric names to
666+ corresponding trial types. If not provided for a metric,
667+ the experiment's default trial type is used.
587668 """
588- # Before setting any metrics, we validate none are already on
589- # the experiment
669+ metrics_to_trial_types = metrics_to_trial_types or {}
590670 for metric in metrics :
591- self .add_tracking_metric (metric )
671+ self .add_tracking_metric (
672+ metric = metric ,
673+ trial_type = metrics_to_trial_types .get (metric .name ),
674+ )
592675 return self
593676
594- def update_tracking_metric (self , metric : Metric ) -> Self :
677+ def update_tracking_metric (
678+ self ,
679+ metric : Metric ,
680+ trial_type : str | None = None ,
681+ ) -> Self :
595682 """Redefine a metric that already exists on the experiment.
596683
597684 Args:
598685 metric: New metric definition.
686+ trial_type: The trial type for which this metric is used. If not
687+ provided, keeps the existing trial type mapping.
599688 """
600689 if metric .name not in self ._tracking_metrics :
601690 raise ValueError (f"Metric `{ metric .name } ` doesn't exist on experiment." )
602691
692+ # Validate trial type if provided
693+ effective_trial_type = (
694+ trial_type
695+ if trial_type is not None
696+ else self ._metric_to_trial_type .get (metric .name , self ._default_trial_type )
697+ )
698+
699+ # Check that optimization config metrics stay on default trial type
700+ oc = self .optimization_config
701+ oc_metrics = oc .metrics if oc else {}
702+ if (
703+ metric .name in oc_metrics
704+ and effective_trial_type != self ._default_trial_type
705+ ):
706+ raise ValueError (
707+ f"Metric `{ metric .name } ` must remain a "
708+ f"`{ self ._default_trial_type } ` metric because it is part of the "
709+ "optimization_config."
710+ )
711+
712+ if not self .supports_trial_type (effective_trial_type ):
713+ raise ValueError (f"`{ effective_trial_type } ` is not a supported trial type." )
714+
603715 self ._tracking_metrics [metric .name ] = metric
716+ self ._metric_to_trial_type [metric .name ] = effective_trial_type
604717 return self
605718
606719 def remove_tracking_metric (self , metric_name : str ) -> Self :
@@ -613,6 +726,7 @@ def remove_tracking_metric(self, metric_name: str) -> Self:
613726 raise ValueError (f"Metric `{ metric_name } ` doesn't exist on experiment." )
614727
615728 del self ._tracking_metrics [metric_name ]
729+ self ._metric_to_trial_type .pop (metric_name , None )
616730 return self
617731
618732 @property
@@ -852,8 +966,21 @@ def _fetch_trial_data(
852966 ) -> dict [str , MetricFetchResult ]:
853967 trial = self .trials [trial_index ]
854968
969+ # If metrics are not provided, fetch all metrics on the experiment for the
970+ # relevant trial type, or the default trial type as a fallback. Otherwise,
971+ # fetch provided metrics.
972+ if metrics is None :
973+ resolved_metrics = [
974+ metric
975+ for metric in list (self .metrics .values ())
976+ if self ._metric_to_trial_type .get (metric .name , self ._default_trial_type )
977+ == trial .trial_type
978+ ]
979+ else :
980+ resolved_metrics = metrics
981+
855982 trial_data = self ._lookup_or_fetch_trials_results (
856- trials = [trial ], metrics = metrics , ** kwargs
983+ trials = [trial ], metrics = resolved_metrics , ** kwargs
857984 )
858985
859986 if trial_index in trial_data :
@@ -1548,39 +1675,79 @@ def __repr__(self) -> str:
15481675 # overridden in the MultiTypeExperiment class.
15491676
15501677 @property
1551- def default_trial_type (self ) -> str | None :
1552- """Default trial type assigned to trials in this experiment.
1553-
1554- In the base experiment class this is always None. For experiments
1555- with multiple trial types, use the MultiTypeExperiment class.
1556- """
1678+ def default_trial_type (self ) -> str :
1679+ """Default trial type assigned to trials in this experiment."""
15571680 return self ._default_trial_type
15581681
1559- def runner_for_trial_type (self , trial_type : str | None ) -> Runner | None :
1682+ def runner_for_trial_type (self , trial_type : str ) -> Runner | None :
15601683 """The default runner to use for a given trial type.
15611684
15621685 Looks up the appropriate runner for this trial type in the trial_type_to_runner.
15631686 """
1687+ # Special case for LONG_ and SHORT_RUN trial types, which we treat as "default"
1688+ # trial types for deployment.
1689+ if (
1690+ trial_type == Keys .SHORT_RUN or trial_type == Keys .LONG_RUN
1691+ ) and self .supports_trial_type (trial_type = Keys .DEFAULT_TRIAL_TYPE ):
1692+ return self ._trial_type_to_runner [Keys .DEFAULT_TRIAL_TYPE ]
1693+
15641694 if not self .supports_trial_type (trial_type ):
15651695 raise ValueError (f"Trial type `{ trial_type } ` is not supported." )
15661696 if (runner := self ._trial_type_to_runner .get (trial_type )) is None :
15671697 return self .runner # return the default runner
15681698 return runner
15691699
1570- def supports_trial_type (self , trial_type : str | None ) -> bool :
1700+ def supports_trial_type (self , trial_type : str ) -> bool :
15711701 """Whether this experiment allows trials of the given type.
15721702
1573- The base experiment class only supports None. For experiments
1574- with multiple trial types, use the MultiTypeExperiment class.
1703+ Checks if the trial type is registered in the trial_type_to_runner mapping.
15751704 """
1576- return (
1577- trial_type is None
1578- # We temporarily allow "short run" and "long run" trial
1579- # types in single-type experiments during development of
1580- # a new ``GenerationStrategy`` that needs them.
1581- or trial_type == Keys .SHORT_RUN
1582- or trial_type == Keys .LONG_RUN
1583- )
1705+ # Special case for LONG_ and SHORT_RUN trial types, which we treat as "default"
1706+ # trial types for deployment.
1707+ if (
1708+ trial_type == Keys .SHORT_RUN or trial_type == Keys .LONG_RUN
1709+ ) and self .supports_trial_type (trial_type = Keys .DEFAULT_TRIAL_TYPE ):
1710+ return True
1711+
1712+ return trial_type in self ._trial_type_to_runner
1713+
1714+ @property
1715+ def is_multi_type (self ) -> bool :
1716+ """Returns True if this experiment has multiple trial types registered."""
1717+ return len (self ._trial_type_to_runner ) > 1
1718+
1719+ @property
1720+ def metric_to_trial_type (self ) -> dict [str , str ]:
1721+ """Read-only mapping of metric names to trial types."""
1722+ return self ._metric_to_trial_type .copy ()
1723+
1724+ def metrics_for_trial_type (self , trial_type : str ) -> list [Metric ]:
1725+ """Returns metrics associated with a specific trial type.
1726+
1727+ Args:
1728+ trial_type: The trial type to get metrics for.
1729+
1730+ Returns:
1731+ List of metrics associated with the given trial type.
1732+ """
1733+ # Special case for LONG_ and SHORT_RUN trial types, which we treat as "default"
1734+ # trial types for deployment.
1735+ if (
1736+ trial_type == Keys .SHORT_RUN or trial_type == Keys .LONG_RUN
1737+ ) and self .supports_trial_type (trial_type = Keys .DEFAULT_TRIAL_TYPE ):
1738+ return [
1739+ self .metrics [metric_name ]
1740+ for metric_name , metric_trial_type in self ._metric_to_trial_type .items ()
1741+ if metric_trial_type == Keys .DEFAULT_TRIAL_TYPE
1742+ ]
1743+
1744+ if not self .supports_trial_type (trial_type ):
1745+ raise ValueError (f"Trial type `{ trial_type } ` is not supported." )
1746+ return [
1747+ self .metrics [metric_name ]
1748+ for metric_name , metric_trial_type in self ._metric_to_trial_type .items ()
1749+ if metric_trial_type == trial_type
1750+ ]
15841751
15851752 def attach_trial (
15861753 self ,
0 commit comments