14
14
from collections import defaultdict , OrderedDict
15
15
from collections .abc import Hashable , Iterable , Mapping
16
16
from datetime import datetime
17
- from functools import partial , reduce
17
+ from functools import partial
18
18
19
19
from typing import Any , cast
20
20
@@ -355,18 +355,17 @@ def arms_by_signature_for_deduplication(self) -> dict[str, Arm]:
355
355
return arms_dict
356
356
357
357
@property
358
- def sum_trial_sizes (self ) -> int :
359
- """Sum of numbers of arms attached to each trial in this experiment."""
360
- return reduce (lambda a , b : a + len (b .arms_by_name ), self ._trials .values (), 0 )
358
+ def metrics (self ) -> dict [str , Metric ]:
359
+ """The metrics attached to the experiment."""
360
+ optimization_config_metrics : dict [str , Metric ] = {}
361
+ if self .optimization_config is not None :
362
+ optimization_config_metrics = self .optimization_config .metrics
363
+ return {** self ._tracking_metrics , ** optimization_config_metrics }
361
364
362
365
@property
363
366
def num_abandoned_arms (self ) -> int :
364
367
"""How many arms attached to this experiment are abandoned."""
365
- abandoned = set ()
366
- for trial in self .trials .values ():
367
- for x in trial .abandoned_arms :
368
- abandoned .add (x )
369
- return len (abandoned )
368
+ return len ({aa for t in self .trials .values () for aa in t .abandoned_arms })
370
369
371
370
@property
372
371
def optimization_config (self ) -> OptimizationConfig | None :
@@ -495,14 +494,6 @@ def remove_tracking_metric(self, metric_name: str) -> Experiment:
495
494
del self ._tracking_metrics [metric_name ]
496
495
return self
497
496
498
- @property
499
- def metrics (self ) -> dict [str , Metric ]:
500
- """The metrics attached to the experiment."""
501
- optimization_config_metrics : dict [str , Metric ] = {}
502
- if self .optimization_config is not None :
503
- optimization_config_metrics = self .optimization_config .metrics
504
- return {** self ._tracking_metrics , ** optimization_config_metrics }
505
-
506
497
def _metrics_by_class (
507
498
self , metrics : list [Metric ] | None = None
508
499
) -> dict [type [Metric ], list [Metric ]]:
@@ -518,6 +509,7 @@ def _metrics_by_class(
518
509
519
510
def fetch_data_results (
520
511
self ,
512
+ trial_indices : Iterable [int ] | None = None ,
521
513
metrics : list [Metric ] | None = None ,
522
514
combine_with_last_data : bool = False ,
523
515
overwrite_existing_data : bool = False ,
@@ -546,43 +538,9 @@ def fetch_data_results(
546
538
"""
547
539
548
540
return self ._lookup_or_fetch_trials_results (
549
- trials = list (self .trials .values ()),
550
- metrics = metrics ,
551
- combine_with_last_data = combine_with_last_data ,
552
- overwrite_existing_data = overwrite_existing_data ,
553
- ** kwargs ,
554
- )
555
-
556
- def fetch_trials_data_results (
557
- self ,
558
- trial_indices : Iterable [int ],
559
- metrics : list [Metric ] | None = None ,
560
- combine_with_last_data : bool = False ,
561
- overwrite_existing_data : bool = False ,
562
- ** kwargs : Any ,
563
- ) -> dict [int , dict [str , MetricFetchResult ]]:
564
- """Fetches data for specific trials on the experiment.
565
-
566
- If a metric fetch fails, the Exception will be captured in the
567
- MetricFetchResult along with a message.
568
-
569
- NOTE: For metrics that are not available while trial is running, the data
570
- may be retrieved from cache on the experiment. Data is cached on the experiment
571
- via calls to `experiment.attach_data` and whether a given metric class is
572
- available while trial is running is determined by the boolean returned from its
573
- `is_available_while_running` class method.
574
-
575
- Args:
576
- trial_indices: Indices of trials, for which to fetch data.
577
- metrics: If provided, fetch data for these metrics instead of the ones
578
- defined on the experiment.
579
- kwargs: keyword args to pass to underlying metrics' fetch data functions.
580
-
581
- Returns:
582
- A nested Dictionary from trial_index => metric_name => result
583
- """
584
- return self ._lookup_or_fetch_trials_results (
585
- trials = self .get_trials_by_indices (trial_indices = trial_indices ),
541
+ trials = self .get_trials_by_indices (trial_indices = trial_indices )
542
+ if trial_indices is not None
543
+ else list (self .trials .values ()),
586
544
metrics = metrics ,
587
545
combine_with_last_data = combine_with_last_data ,
588
546
overwrite_existing_data = overwrite_existing_data ,
@@ -591,6 +549,7 @@ def fetch_trials_data_results(
591
549
592
550
def fetch_data (
593
551
self ,
552
+ trial_indices : Iterable [int ] | None = None ,
594
553
metrics : list [Metric ] | None = None ,
595
554
combine_with_last_data : bool = False ,
596
555
overwrite_existing_data : bool = False ,
@@ -618,63 +577,15 @@ def fetch_data(
618
577
Data for the experiment.
619
578
"""
620
579
621
- results = self ._lookup_or_fetch_trials_results (
622
- trials = list ( self . trials . values ()) ,
580
+ results = self .fetch_data_results (
581
+ trial_indices = trial_indices ,
623
582
metrics = metrics ,
624
583
combine_with_last_data = combine_with_last_data ,
625
584
overwrite_existing_data = overwrite_existing_data ,
626
585
** kwargs ,
627
586
)
628
-
629
- base_metric_cls = (
630
- MapMetric if self .default_data_constructor == MapData else Metric
631
- )
632
-
633
- return base_metric_cls ._unwrap_experiment_data_multi (
634
- results = results ,
635
- )
636
-
637
- def fetch_trials_data (
638
- self ,
639
- trial_indices : Iterable [int ],
640
- metrics : list [Metric ] | None = None ,
641
- combine_with_last_data : bool = False ,
642
- overwrite_existing_data : bool = False ,
643
- ** kwargs : Any ,
644
- ) -> Data :
645
- """Fetches data for specific trials on the experiment.
646
-
647
- NOTE: For metrics that are not available while trial is running, the data
648
- may be retrieved from cache on the experiment. Data is cached on the experiment
649
- via calls to `experiment.attach_data` and whetner a given metric class is
650
- available while trial is running is determined by the boolean returned from its
651
- `is_available_while_running` class method.
652
-
653
- NOTE: This can be lossy (ex. a MapData could get implicitly cast to a Data and
654
- lose rows) if Experiment.default_data_type is misconfigured!
655
-
656
- Args:
657
- trial_indices: Indices of trials, for which to fetch data.
658
- metrics: If provided, fetch data for these metrics instead of the ones
659
- defined on the experiment.
660
- kwargs: Keyword args to pass to underlying metrics' fetch data functions.
661
-
662
- Returns:
663
- Data for the specific trials on the experiment.
664
- """
665
-
666
- results = self ._lookup_or_fetch_trials_results (
667
- trials = self .get_trials_by_indices (trial_indices = trial_indices ),
668
- metrics = metrics ,
669
- combine_with_last_data = combine_with_last_data ,
670
- overwrite_existing_data = overwrite_existing_data ,
671
- ** kwargs ,
672
- )
673
-
674
- base_metric_cls = (
675
- MapMetric if self .default_data_constructor == MapData else Metric
676
- )
677
- return base_metric_cls ._unwrap_experiment_data_multi (
587
+ use_map_data = self .default_data_constructor == MapData
588
+ return (MapMetric if use_map_data else Metric )._unwrap_experiment_data_multi (
678
589
results = results ,
679
590
)
680
591
0 commit comments