diff --git a/README.md b/README.md index e546e8356d..d0b96fb6bb 100644 --- a/README.md +++ b/README.md @@ -12,10 +12,10 @@ SMAC offers a robust and flexible framework for Bayesian Optimization to support hyperparameter configurations for their (Machine Learning) algorithms, datasets and applications at hand. The main core consists of Bayesian Optimization in combination with an aggressive racing mechanism to efficiently decide which of two configurations performs better. -SMAC3 is written in Python3 and continuously tested with Python 3.8, 3.9, and 3.10. Its Random +SMAC3 is written in Python3 and continuously tested with Python 3.8, 3.9, and 3.10 (and works with newer python versions). Its Random Forest is written in C++. In further texts, SMAC is representatively mentioned for SMAC3. -> [Documentation](https://automl.github.io/SMAC3) +> [Documentation](https://automl.github.io/SMAC3/latest/) > [Roadmap](https://github.com/orgs/automl/projects/5/views/2) @@ -36,7 +36,7 @@ We hope you enjoy this new user experience as much as we do. 🚀 ## Installation -This instruction is for the installation on a Linux system, for Windows and Mac and further information see the [documentation](https://automl.github.io/SMAC3/main/1_installation.html). +This instruction is for the installation on a Linux system, for Windows and Mac and further information see the [documentation](https://automl.github.io/SMAC3/latest/1_installation/). Create a new environment with python 3.10 and make sure swig is installed either on your system or inside the environment. We demonstrate the installation via anaconda in the following: @@ -94,7 +94,7 @@ smac = HyperparameterOptimizationFacade(scenario, train) incumbent = smac.optimize() ``` -More examples can be found in the [documentation](https://automl.github.io/SMAC3/main/examples/). +More examples can be found in the [documentation](https://automl.github.io/SMAC3/latest/examples/1%20Basics/1_quadratic_function/). ## Visualization via DeepCAVE @@ -123,7 +123,7 @@ For all other inquiries, please write an email to smac[at]ai[dot]uni[dash]hannov ## Miscellaneous SMAC3 is developed by the [AutoML Groups of the Universities of Hannover and -Freiburg](http://www.automl.org/). +Freiburg](http://www.automl.org/). It is a featured optimizer on [AutoML Space](https://automl.space/automl-tools/). If you have found a bug, please report to [issues](https://github.com/automl/SMAC3/issues). Moreover, we are appreciating any kind of help. Find our guidelines for contributing to this package @@ -144,4 +144,4 @@ If you use SMAC in one of your research projects, please cite our } ``` -Copyright (C) 2016-2022 [AutoML Group](http://www.automl.org). +Copyright (c) 2025, [Leibniz University Hannover - Institute of AI](https://www.ai.uni-hannover.de/) \ No newline at end of file diff --git a/examples/2_multi_fidelity/1_mlp_epochs.py b/examples/2_multi_fidelity/1_mlp_epochs.py index 48c027d3bd..b4248ac256 100644 --- a/examples/2_multi_fidelity/1_mlp_epochs.py +++ b/examples/2_multi_fidelity/1_mlp_epochs.py @@ -80,7 +80,7 @@ def configspace(self) -> ConfigurationSpace: return cs - def train(self, config: Configuration, seed: int = 0, budget: int = 25) -> float: + def train(self, config: Configuration, seed: int = 0, instance: str = "0", budget: int = 25) -> dict[str, float]: # For deactivated parameters (by virtue of the conditions), # the configuration stores None-values. # This is not accepted by the MLP, so we replace them with placeholder values. 
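As a side note to the signature change above: a minimal, self-contained sketch (not the tutorial's MLP, and not part of this diff) of a target function matching the new interface, i.e. accepting an `instance` argument and returning a dict keyed by the objective name registered in the `Scenario`:

```python
from __future__ import annotations

from ConfigSpace import Configuration, ConfigurationSpace

configspace = ConfigurationSpace({"x": (-5.0, 5.0)})


def train(config: Configuration, seed: int = 0, instance: str = "0", budget: int = 25) -> dict[str, float]:
    # Toy stand-in: the instance shifts the optimum, the budget only scales the value.
    shift = float(instance)
    value = (config["x"] - shift) ** 2 / max(budget, 1)
    # The returned key must match the objective name passed to the Scenario (here "accuracy").
    return {"accuracy": value}
```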
@@ -106,7 +106,7 @@ def train(self, config: Configuration, seed: int = 0, budget: int = 25) -> float cv = StratifiedKFold(n_splits=5, random_state=seed, shuffle=True) # to make CV splits consistent score = cross_val_score(classifier, dataset.data, dataset.target, cv=cv, error_score="raise") - return 1 - np.mean(score) + return {"accuracy": 1 - np.mean(score)} def plot_trajectory(facades: list[AbstractFacade]) -> None: @@ -147,9 +147,11 @@ def plot_trajectory(facades: list[AbstractFacade]) -> None: mlp.configspace, walltime_limit=60, # After 60 seconds, we stop the hyperparameter optimization n_trials=500, # Evaluate max 500 different trials - min_budget=1, # Train the MLP using a hyperparameter configuration for at least 5 epochs - max_budget=25, # Train the MLP using a hyperparameter configuration for at most 25 epochs - n_workers=8, + instances=[str(i) for i in range(10)], + objectives="accuracy", + # min_budget=1, # Train the MLP using a hyperparameter configuration for at least 5 epochs + # max_budget=25, # Train the MLP using a hyperparameter configuration for at most 25 epochs + n_workers=4, ) # We want to run five random configurations before starting the optimization. diff --git a/smac/acquisition/function/abstract_acquisition_function.py b/smac/acquisition/function/abstract_acquisition_function.py index 519f5b3d0f..42b94675e6 100644 --- a/smac/acquisition/function/abstract_acquisition_function.py +++ b/smac/acquisition/function/abstract_acquisition_function.py @@ -50,7 +50,7 @@ def update(self, model: AbstractModel, **kwargs: Any) -> None: This method will be called after fitting the model, but before maximizing the acquisition function. As an examples, EI uses it to update the current fmin. The default implementation only updates the - attributes of the acqusition function which are already present. + attributes of the acquisition function which are already present. Calls `_update` to update the acquisition function attributes. @@ -65,7 +65,7 @@ def update(self, model: AbstractModel, **kwargs: Any) -> None: self._update(**kwargs) def _update(self, **kwargs: Any) -> None: - """Update acsquisition function attributes + """Update acquisition function attributes Might be different for each child class. 
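Relatedly, a minimal sketch of the `update()`/`_update()` contract documented here: subclasses override `_update` to cache whatever their acquisition value depends on and `_compute` to score candidates. `MeanMinimization` is a hypothetical toy, not part of SMAC.

```python
from __future__ import annotations

import numpy as np

from smac.acquisition.function.abstract_acquisition_function import (
    AbstractAcquisitionFunction,
)


class MeanMinimization(AbstractAcquisitionFunction):
    """Toy acquisition function: prefer the lowest predicted mean, ignore uncertainty."""

    @property
    def name(self) -> str:
        return "Mean Minimization"

    def _update(self, **kwargs) -> None:
        # Cache anything beyond the model itself, e.g. the current best observed cost.
        self._eta = kwargs.get("eta")

    def _compute(self, X: np.ndarray) -> np.ndarray:
        mean, _ = self.model.predict_marginalized(X)
        return -mean  # maximizing the acquisition value == minimizing the predicted mean
```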
""" diff --git a/smac/acquisition/function/expected_hypervolume.py b/smac/acquisition/function/expected_hypervolume.py new file mode 100644 index 0000000000..8cdd30eff9 --- /dev/null +++ b/smac/acquisition/function/expected_hypervolume.py @@ -0,0 +1,328 @@ +from __future__ import annotations + +from typing import Any, Iterator + +from ConfigSpace import Configuration + +import pygmo +import numpy as np + +from smac.intensifier.abstract_intensifier import AbstractIntensifier +from smac.runhistory import TrialInfo, RunHistory +from smac.runhistory.encoder import AbstractRunHistoryEncoder +from smac.runhistory.dataclasses import InstanceSeedBudgetKey +from smac.scenario import Scenario +from smac.utils.configspace import get_config_hash +from smac.utils.logging import get_logger +from smac.acquisition.function.abstract_acquisition_function import AbstractAcquisitionFunction +from smac.model.abstract_model import AbstractModel +from smac.utils.multi_objective import normalize_costs + +# import torch +# from botorch.acquisition.multi_objective import ExpectedHypervolumeImprovement +# from botorch.models.model import Model +# from botorch.utils.multi_objective.box_decompositions.non_dominated import ( +# NondominatedPartitioning, +# ) + +__copyright__ = "Copyright 2022, automl.org" +__license__ = "3-clause BSD" + +logger = get_logger(__name__) + +# class _PosteriorProxy(object): +# def __init__(self) -> None: +# self.mean: Tensor = [] +# self.variance: Tensor = [] + +# class _ModelProxy(Model, ABC): +# def __init__(self, model: AbstractModel, objective_bounds: list[tuple[float, float]]): +# super(_ModelProxy).__init__() +# self.model = model +# self._objective_bounds = objective_bounds +# +# def posterior(self, X: Tensor, **kwargs: Any) -> _PosteriorProxy: +# """Docstring +# X: A `b x q x d`-dim Tensor, where `d` is the dimension of the +# feature space, `q` is the number of points considered jointly, +# and `b` is the batch dimension. +# +# +# A `Posterior` object, representing a batch of `b` joint distributions +# over `q` points and `m` outputs each. +# """ +# assert X.shape[1] == 1 +# X = X.reshape([X.shape[0], -1]).numpy() # 3D -> 2D +# +# # predict +# # start_time = time.time() +# # print(f"Start predicting ") +# mean, var_ = self.model.predict_marginalized(X) +# normalized_mean = np.array([normalize_costs(m, self._objective_bounds) for m in mean]) +# scale = normalized_mean / mean +# var_ *= scale # Scale variance accordingly +# mean = normalized_mean +# # print(f"Done in {time.time() - start_time}s") +# post = _PosteriorProxy() +# post.mean = torch.asarray(mean).reshape(X.shape[0], 1, -1) # 2D -> 3D +# post.variance = torch.asarray(var_).reshape(X.shape[0], 1, -1) # 2D -> 3D +# +# return post + +class AbstractHVI(AbstractAcquisitionFunction): + def __init__(self): + """Computes for a given x the predicted hypervolume improvement as + acquisition value. 
+ """ + super(AbstractHVI, self).__init__() + self._required_updates = ("model",) + self._reference_point = None + self._objective_bounds = None + + self._runhistory: RunHistory | None = None + self._runhistory_encoder: AbstractRunHistoryEncoder | None = None + + @property + def runhistory(self) -> RunHistory: + return self._runhistory + + @runhistory.setter + def runhistory(self, runhistory: RunHistory): + self._runhistory = runhistory + + @property + def runhistory_encoder(self) -> AbstractRunHistoryEncoder: + return self._runhistory_encoder + + @runhistory_encoder.setter + def runhistory_encoder(self, runhistory_encoder: AbstractRunHistoryEncoder): + self._runhistory_encoder = runhistory_encoder + + @property + def name(self) -> str: + return "Abstract Hypervolume Improvement" + + def _update(self, **kwargs: Any) -> None: + super(AbstractHVI, self)._update(**kwargs) + + incumbents: list[Configuration] = kwargs.get("incumbents", None) + if incumbents is None: + raise ValueError(f"Incumbents are not passed properly.") + if len(incumbents) == 0: + raise ValueError(f"No incumbents here. Did the intensifier properly " + "update the incumbents in the runhistory?") + + objective_bounds = np.array(self.runhistory.objective_bounds) + self._objective_bounds = self.runhistory_encoder.transform_response_values( + objective_bounds) + self._reference_point = [1.1] * len(self._objective_bounds) + + def get_hypervolume(self, points: np.ndarray = None, reference_point: list = None) -> float: + """ + Compute the hypervolume + + Parameters + ---------- + points : np.ndarray + A 2d numpy array. 1st dimension is an entity and the 2nd dimension are the costs + reference_point : list + + Return + ------ + hypervolume: float + """ + # Normalize the objectives here to give equal attention to the objectives when computing the HV + points = [normalize_costs(p, self._objective_bounds) for p in points] + + hv = pygmo.hypervolume(points) + # if reference_point is None: + # self._reference_point = hv.refpoint(offset=1) + return hv.compute(self._reference_point) + + def _compute(self, X: np.ndarray) -> np.ndarray: + """Computes the PHVI values and its derivatives. + + Parameters + ---------- + X: np.ndarray(N, D), The input points where the acquisition function + should be evaluated. The dimensionality of X is (N, D), with N as + the number of points to evaluate at and D is the number of + dimensions of one X. + + Returns + ------- + np.ndarray(N,1) + Expected HV Improvement of X + """ + if len(X.shape) == 1: + X = X[:, np.newaxis] + + # TODO non-dominated sorting of costs. Compute EHVI only until the EHVI is not expected to improve anymore. + # Option 1: Supplement missing instances of population with acq. function to get predicted performance over + # all instances. 
Idea is this prevents optimizing for the initial instances which get it stuck in local optima + # Option 2: Only on instances of population + # Option 3: EVHI per instance and aggregate afterwards + mean, var_ = self.model.predict_marginalized(X) #Expected to be not normalized + + + phvi = np.zeros(len(X)) + for i, indiv in enumerate(mean): + points = list(self.population_costs) + [indiv] + hv = self.get_hypervolume(points) + phvi[i] = hv - self.population_hv + + # if len(X) == 10000: + # for op in ["max", "min", "mean", "median"]: + # val = getattr(np, op)(phvi) + # print(f"{op:6} - {val}") + # time.sleep(1.5) + + return phvi.reshape(-1, 1) + +# class EHVI(AbstractHVI): +# def __init__(self): +# super(EHVI, self).__init__() +# self._ehvi: ExpectedHypervolumeImprovement | None = None +# +# @property +# def name(self) -> str: +# return "Expected Hypervolume Improvement" +# +# def _update(self, **kwargs: Any) -> None: +# super(EHVI, self)._update(**kwargs) +# incumbents: list[Configuration] = kwargs.get("incumbents", None) +# +# # Update EHVI +# # Prediction all +# population_configs = incumbents +# population_X = np.array([config.get_array() for config in population_configs]) +# population_costs, _ = self.model.predict_marginalized(population_X) +# # Normalize the objectives here to give equal attention to the objectives when computing the HV +# population_costs = [normalize_costs(p, self._objective_bounds) for p in population_costs] +# +# # BOtorch EHVI implementation +# bomodel = _ModelProxy(self.model, self._objective_bounds) +# # ref_point = pygmo.hypervolume(population_costs).refpoint( +# # offset=1 +# # ) # TODO get proper reference points from user/cutoffs +# ref_point = [1.1] * len(self._objective_bounds) +# # ref_point = torch.asarray(ref_point) +# # TODO partition from all runs instead of only population? +# # TODO NondominatedPartitioning and ExpectedHypervolumeImprovement seem no too difficult to implement natively +# # TODO pass along RNG +# # Transfrom the objective space to cells based on the population +# partitioning = NondominatedPartitioning(torch.asarray(ref_point), torch.asarray(population_costs)) +# self._ehvi = ExpectedHypervolumeImprovement(bomodel, ref_point, partitioning) +# +# def _compute(self, X: np.ndarray) -> np.ndarray: +# """Computes the EHVI values and its derivatives. +# +# Parameters +# ---------- +# X: np.ndarray(N, D), The input points where the acquisition function +# should be evaluated. The dimensionality of X is (N, D), with N as +# the number of points to evaluate at and D is the number of +# dimensions of one X. +# +# Returns +# ------- +# np.ndarray(N,1) +# Expected HV Improvement of X +# """ +# if self._ehvi is None: +# raise ValueError(f"The expected hypervolume improvement is not defined yet. Call self.update.") +# +# if len(X.shape) == 1: +# X = X[:, np.newaxis] +# +# # m, var_ = self.model.predict_marginalized_over_instances(X) +# # Find a way to propagate the variance into the HV +# boX = torch.asarray(X).reshape(X.shape[0], 1, -1) # 2D -> #3D +# improvements = self._ehvi(boX).numpy().reshape(-1, 1) # TODO here are the expected hv improvements computed. +# return improvements +# +# # TODO non-dominated sorting of costs. Compute EHVI only until the EHVI is not expected to improve anymore. +# # Option 1: Supplement missing instances of population with acq. function to get predicted performance over +# # all instances. 
Idea is this prevents optimizing for the initial instances which get it stuck in local optima +# # Option 2: Only on instances of population +# # Option 3: EVHI per instance and aggregate afterwards +# # ehvi = np.zeros(len(X)) +# # for i, indiv in enumerate(m): +# # ehvi[i] = self.get_hypervolume(population_costs + [indiv]) - population_hv +# # +# # return ehvi.reshape(-1, 1) + +class PHVI(AbstractHVI): + + def __init__(self): + super(PHVI, self).__init__() + self.population_hv = None + self.population_costs = None + + @property + def name(self) -> str: + return "Predicted Hypervolume Improvement" + + def _update(self, **kwargs: Any) -> None: + super(PHVI, self)._update(**kwargs) + incumbents: list[Configuration] = kwargs.get("incumbents", None) + + # Update PHVI + # Prediction all + population_configs = incumbents + population_X = np.array([config.get_array() for config in population_configs]) + population_costs, _ = self.model.predict_marginalized(population_X) + + # Compute HV + population_hv = self.get_hypervolume(population_costs) + + self.population_costs = population_costs + self.population_hv = population_hv + + logger.info(f"New population HV: {population_hv}") + + def get_hypervolume(self, points: np.ndarray = None, reference_point: list = None) -> float: + """ + Compute the hypervolume + + Parameters + ---------- + points : np.ndarray + A 2d numpy array. 1st dimension is an entity and the 2nd dimension are the costs + reference_point : list + + Return + ------ + hypervolume: float + """ + # Normalize the objectives here to give equal attention to the objectives when computing the HV + points = [normalize_costs(p, self._objective_bounds) for p in points] + hv = pygmo.hypervolume(points) + return hv.compute(self._reference_point) + + def _compute(self, X: np.ndarray) -> np.ndarray: + """Computes the PHVI values and its derivatives. + + Parameters + ---------- + X: np.ndarray(N, D), The input points where the acquisition function + should be evaluated. The dimensionality of X is (N, D), with N as + the number of points to evaluate at and D is the number of + dimensions of one X. + + Returns + ------- + np.ndarray(N,1) + Expected HV Improvement of X + """ + if len(X.shape) == 1: + X = X[:, np.newaxis] + + mean, _ = self.model.predict_marginalized(X) #Expected to be not normalized + phvi = np.zeros(len(X)) + for i, indiv in enumerate(mean): + points = list(self.population_costs) + [indiv] + hv = self.get_hypervolume(points) + phvi[i] = hv - self.population_hv + + return phvi.reshape(-1, 1) diff --git a/smac/acquisition/maximizer/local_search.py b/smac/acquisition/maximizer/local_search.py index c6f545f9a6..300a6a599f 100644 --- a/smac/acquisition/maximizer/local_search.py +++ b/smac/acquisition/maximizer/local_search.py @@ -65,7 +65,7 @@ def __init__( seed=seed, ) - self._max_steps = max_steps + self._max_steps = max_steps if max_steps is not None else np.inf self._n_steps_plateau_walk = n_steps_plateau_walk self._vectorization_min_obtain = vectorization_min_obtain self._vectorization_max_obtain = vectorization_max_obtain @@ -90,15 +90,10 @@ def _maximize( n_points: int, additional_start_points: list[tuple[float, Configuration]] | None = None, ) -> list[tuple[float, Configuration]]: - """Start a local search from the given startpoints. Iteratively collect neighbours - using Configspace.utils.get_one_exchange_neighbourhood and evaluate them. - If the new config is better than the current best, the local search is coninued from the - new config. 
+ """Start a local search from the given startpoint. Quit if either the max number of steps is reached or - no neighbor with a higher improvement was found or the number of local steps self._n_steps_plateau_walk - for each of the starting point is depleted. - + no neighbor with an higher improvement was found. Parameters ---------- @@ -208,18 +203,16 @@ def _get_init_points_from_previous_configs( costs = self._acquisition_function.model.predict_marginalized(conf_array)[0] assert len(conf_array) == len(costs), (conf_array.shape, costs.shape) - # In case of the predictive model returning the prediction for more than one objective per configuration - # (for example multi-objective or EIPS) it is not immediately clear how to sort according to the cost - # of a configuration. Therefore, we simply follow the ParEGO approach and use a random scalarization. + sort_objectives = [costs.flatten()] if len(costs.shape) == 2 and costs.shape[1] > 1: - weights = np.array([self._rng.rand() for _ in range(costs.shape[1])]) - weights = weights / np.sum(weights) - costs = costs @ weights + sort_objectives = self._create_sort_keys(costs=costs) # From here: make argsort result to be random between equal values # http://stackoverflow.com/questions/20197990/how-to-make-argsort-result-to-be-random-between-equal-values random = self._rng.rand(len(costs)) - indices = np.lexsort((random.flatten(), costs.flatten())) # Last column is primary sort key! + + # Last column is primary sort key! + indices = np.lexsort((random.flatten(), *sort_objectives)) # Cannot use zip here because the indices array cannot index the # rand_configs list, because the second is a pure python list @@ -232,19 +225,61 @@ def _get_init_points_from_previous_configs( else: additional_start_points = [] - init_points = [] - init_points_as_set: set[Configuration] = set() - for cand in itertools.chain( + candidates = itertools.chain( configs_previous_runs_sorted, previous_configs_sorted_by_cost, additional_start_points, - ): - if cand not in init_points_as_set: - init_points.append(cand) - init_points_as_set.add(cand) + ) + init_points = self._unique_list(candidates) return init_points + def _create_sort_keys(self, costs: np.array) -> list[list[float]]: + """Sort costs by random scalarization + + In case of the predictive model returning the prediction for more than one objective per configuration + (for example multi-objective or EIPS) it is not immediately clear how to sort according to the cost + of a configuration. Therefore, we simply follow the ParEGO approach and use a random scalarization. + + Parameters + ---------- + costs : np.array + Cost(s) per config + + Returns + ------- + list[list[float]] + Sorting sequence for lexsort + """ + weights = np.array([self._rng.rand() for _ in range(costs.shape[1])]) + weights = weights / np.sum(weights) + costs = costs @ weights + sort_objectives = [costs.flatten()] + return sort_objectives + + + @staticmethod + def _unique_list(elements: list | itertools.chain) -> list: + """ + Returns the list with only unique elements while remaining the list order. 
+ + Parameters + ---------- + elements : list | itertools.chain + + Returns + ------- + A list with unique elements with preserved order + """ + return_list = [] + return_list_as_set = set() + for e in elements: + if e not in return_list_as_set: + return_list.append(e) + return_list_as_set.add(e) + + return return_list + def _search( self, start_points: list[Configuration], @@ -302,13 +337,13 @@ def _search( local_search_steps = [0] * num_candidates # tracking the number of neighbors looked at for logging purposes neighbors_looked_at = [0] * num_candidates - # tracking the number of neighbors generated for logging purposse + # tracking the number of neighbors generated for logging purposes neighbors_generated = [0] * num_candidates # how many neighbors were obtained for the i-th local search. Important to map the individual acquisition # function values to the correct local search run obtain_n = [self._vectorization_min_obtain] * num_candidates # Tracking the time it takes to compute the acquisition function - times = [] + times_per_iteration: list[float] = [] # Set up the neighborhood generators neighborhood_iterators = [] @@ -360,11 +395,12 @@ def _search( obtain_n[i] = len(neighbors_for_i) neighbors.extend(neighbors_for_i) + logger.debug(f"Iteration {num_iters} with {np.count_nonzero(active)} active searches and {len(neighbors)} aqcuisition function calls.") if len(neighbors) != 0: start_time = time.time() acq_val = self._acquisition_function(neighbors) end_time = time.time() - times.append(end_time - start_time) + times_per_iteration.append(end_time - start_time) if np.ndim(acq_val.shape) == 0: acq_val = np.asarray([acq_val]) @@ -421,7 +457,7 @@ def _search( continue if obtain_n[i] == 0 or improved[i]: - obtain_n[i] = 2 + obtain_n[i] = self._vectorization_min_obtain else: obtain_n[i] = obtain_n[i] * 2 obtain_n[i] = min(obtain_n[i], self._vectorization_max_obtain) @@ -432,7 +468,12 @@ def _search( candidates[i] = neighbors_w_equal_acq[i][0] neighbors_w_equal_acq[i] = [] n_no_plateau_walk[i] += 1 - if n_no_plateau_walk[i] >= self._n_steps_plateau_walk: + + if n_no_plateau_walk[i] >= self._n_steps_plateau_walk or local_search_steps[i] >= self._max_steps: + message = f"Local search {i}: Stop search after walking {n_no_plateau_walk[i]} plateaus after {neighbors_looked_at[i]}." + if local_search_steps[i] >= self._max_steps: + message += f" Reached max_steps ({self._max_steps}) of local search." + logger.debug(message) active[i] = False continue @@ -442,11 +483,10 @@ def _search( ) logger.debug( - "Local searches took %s steps and looked at %s configurations. Computing the acquisition function in " - "vectorized for took %f seconds on average.", - local_search_steps, - neighbors_looked_at, - np.mean(times), + f"Local searches took {local_search_steps} steps and looked at {neighbors_looked_at} configurations." + f"Computing the acquisition function for each search took {np.sum(times_per_iteration)/num_candidates}" + f"(prev {np.mean(times_per_iteration)}) seconds on average and each acquisition function call took {times_per_iteration/np.sum(neighbors_looked_at)} seconds on average." + f"In total the whole procedure took {np.sum(times_per_iteration)} seconds to look at {np.sum(neighbors_looked_at)} configurations." 
) return [(a, i) for a, i in zip(acq_val_candidates, candidates)] diff --git a/smac/acquisition/maximizer/multi_objective_search.py b/smac/acquisition/maximizer/multi_objective_search.py new file mode 100644 index 0000000000..13283b7322 --- /dev/null +++ b/smac/acquisition/maximizer/multi_objective_search.py @@ -0,0 +1,135 @@ +from __future__ import annotations + +from typing import Any + +import itertools +import time + +from pygmo import fast_non_dominated_sorting + +import numpy as np +from ConfigSpace import Configuration, ConfigurationSpace +from ConfigSpace.exceptions import ForbiddenValueError + +from smac.acquisition.function import AbstractAcquisitionFunction +from smac.acquisition.maximizer.local_search import LocalSearch +from smac.acquisition.maximizer.local_and_random_search import LocalAndSortedRandomSearch +from smac.utils.configspace import ( + convert_configurations_to_array, + get_one_exchange_neighbourhood, +) +from smac.utils.logging import get_logger + +__copyright__ = "Copyright 2022, automl.org" +__license__ = "3-clause BSD" + +logger = get_logger(__name__) + + +class MOLocalSearch(LocalSearch): + def _get_initial_points( + self, + previous_configs: list[Configuration], + n_points: int, + additional_start_points: list[tuple[float, Configuration]] | None, + ) -> list[Configuration]: + """Get initial points to start search from. + + If we already have a population, add those to the initial points. + + Parameters + ---------- + previous_configs : list[Configuration] + Previous configuration (e.g., from the runhistory). + n_points : int + Number of initial points to be generated. + additional_start_points : list[tuple[float, Configuration]] | None + Additional starting points. + + Returns + ------- + list[Configuration] + A list of initial points/configurations. + """ + init_points = super()._get_initial_points(previous_configs=previous_configs, n_points=n_points, additional_start_points=additional_start_points) + + # Add population to Local search + # TODO where is population saved? update accordingly + if len(stats.population) > 0: + population = [runhistory.ids_config[confid] for confid in stats.population] + init_points = self._unique_list(itertools.chain(population, init_points)) + return init_points + + def _create_sort_keys(self, costs: np.array) -> list[list[float]]: + """Non-Dominated Sorting of Costs + + In case of the predictive model returning the prediction for more than one objective per configuration + (for example multi-objective or EIPS) we sort here based on the dominance order. In each front + configurations are sorted on the number of points they dominate overall. + + Parameters + ---------- + costs : np.array + Cost(s) per config + + Returns + ------- + list[list[float]] + Sorting sequence for lexsort + """ + _, domination_list, _, non_domination_rank = fast_non_dominated_sorting(costs) + domination_list = [len(i) for i in domination_list] + sort_objectives = [domination_list, non_domination_rank] # Last column is primary sort key! + return sort_objectives + + +class MOLocalAndSortedRandomSearch(LocalAndSortedRandomSearch): + """Local and Random Search for Multi-Objective + + This optimizer performs local search from the previous best points according, to the acquisition function, uses the + acquisition function to sort randomly sampled configurations. Random configurations are interleaved by the main SMAC + code. 
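For contrast with the scalarization above, a small sketch of the sort keys that `MOLocalSearch._create_sort_keys` derives from pygmo's `fast_non_dominated_sorting`: the non-domination rank is the primary key and the number of dominated points the secondary one.

```python
import numpy as np
from pygmo import fast_non_dominated_sorting

costs = np.array([[0.2, 0.8], [0.5, 0.5], [0.6, 0.6], [0.9, 0.1]])

_, domination_list, _, non_domination_rank = fast_non_dominated_sorting(costs)
n_dominated = [len(d) for d in domination_list]  # how many points each point dominates

random_tiebreak = np.random.RandomState(0).rand(len(costs))
# Last key is the primary sort key: points are grouped by Pareto front first,
# then ordered by their domination count, with random tie-breaking.
indices = np.lexsort((random_tiebreak, n_dominated, non_domination_rank))
print(indices)
```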
+ + Parameters + ---------- + configspace : ConfigurationSpace + acquisition_function : AbstractAcquisitionFunction | None, defaults to None + challengers : int, defaults to 5000 + Number of challengers. + max_steps: int | None, defaults to None + [LocalSearch] Maximum number of steps that the local search will perform. + n_steps_plateau_walk: int, defaults to 10 + [LocalSearch] number of steps during a plateau walk before local search terminates. + local_search_iterations: int, defauts to 10 + [Local Search] number of local search iterations. + seed : int, defaults to 0 + """ + + def __init__( + self, + configspace: ConfigurationSpace, + acquisition_function: AbstractAcquisitionFunction | None = None, + challengers: int = 5000, + max_steps: int | None = None, + n_steps_plateau_walk: int = 10, + local_search_iterations: int = 10, + seed: int = 0, + ) -> None: + super().__init__( + configspace=configspace, + acquisition_function=acquisition_function, + challengers=challengers, + max_steps=max_steps, + n_steps_plateau_walk=n_steps_plateau_walk, + local_search_iterations=local_search_iterations, + seed=seed, + ) + + self.local_search = MOLocalSearch( + configspace=configspace, + acquisition_function=acquisition_function, + challengers=challengers, + max_steps=max_steps, + n_steps_plateau_walk=n_steps_plateau_walk, + seed=seed + ) diff --git a/smac/facade/abstract_facade.py b/smac/facade/abstract_facade.py index a5db5ab521..4d11e68ab1 100644 --- a/smac/facade/abstract_facade.py +++ b/smac/facade/abstract_facade.py @@ -415,11 +415,15 @@ def get_multi_objective_algorithm(scenario: Scenario) -> AbstractMultiObjectiveA def get_config_selector( scenario: Scenario, *, - retrain_after: int = 8, + retrain_after: int | None = 8, + retrain_wallclock_ratio: int | None = None, retries: int = 16, ) -> ConfigSelector: """Returns the default configuration selector.""" - return ConfigSelector(scenario, retrain_after=retrain_after, retries=retries) + return ConfigSelector(scenario, + retrain_after=retrain_after, + retrain_wallclock_ratio=retrain_wallclock_ratio, + retries=retries) def _get_optimizer(self) -> SMBO: """Fills the SMBO with all the pre-initialized components.""" diff --git a/smac/facade/algorithm_configuration_facade.py b/smac/facade/algorithm_configuration_facade.py index a82e2f92ca..1756ffe69f 100644 --- a/smac/facade/algorithm_configuration_facade.py +++ b/smac/facade/algorithm_configuration_facade.py @@ -15,6 +15,7 @@ from smac.runhistory.encoder.encoder import RunHistoryEncoder from smac.scenario import Scenario from smac.utils.logging import get_logger +from smac.intensifier.mixins import intermediate_update, intermediate_decision __copyright__ = "Copyright 2022, automl.org" __license__ = "3-clause BSD" @@ -115,7 +116,12 @@ def get_intensifier( max_incumbents : int, defaults to 10 How many incumbents to keep track of in the case of multi-objective. 
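The `NewIntensifier` assembled just below relies on cooperative mixin composition via Python's method resolution order; a generic toy example of that pattern (the classes here are stand-ins, not SMAC's actual mixin interfaces):

```python
class Base:
    def decide(self, value: float) -> str:
        return f"base({value})"


class ThresholdMixin:
    def decide(self, value: float) -> str:
        # May short-circuit, otherwise defers to the next class in the MRO.
        if value < 0:
            return "reject"
        return super().decide(value)


class LoggingMixin:
    def decide(self, value: float) -> str:
        result = super().decide(value)
        print(f"decide({value}) -> {result}")
        return result


class Combined(LoggingMixin, ThresholdMixin, Base):
    pass


Combined().decide(1.5)  # resolves LoggingMixin -> ThresholdMixin -> Base, left to right
```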
""" - return Intensifier( + class NewIntensifier(intermediate_decision.NewCostDominatesOldCost, + intermediate_update.ClosestIncumbentComparison, + Intensifier): + pass + + return NewIntensifier( scenario=scenario, max_config_calls=max_config_calls, max_incumbents=max_incumbents, diff --git a/smac/facade/multi_objective_facade.py b/smac/facade/multi_objective_facade.py new file mode 100644 index 0000000000..ccb842525a --- /dev/null +++ b/smac/facade/multi_objective_facade.py @@ -0,0 +1,194 @@ +from __future__ import annotations + +from ConfigSpace import Configuration + +from smac.acquisition.function.expected_improvement import EI +from smac.acquisition.function.expected_hypervolume import PHVI +from smac.facade.abstract_facade import AbstractFacade +from smac.initial_design.default_design import DefaultInitialDesign +from smac.intensifier.intensifier import Intensifier +from smac.intensifier.multi_objective_intensifier import MOIntensifier +from smac.intensifier.mixins import intermediate_update, intermediate_decision, update_incumbent +from smac.model.random_forest.random_forest import RandomForest +from smac.model.multi_objective_model import MultiObjectiveModel +from smac.multi_objective.aggregation_strategy import NoAggregationStrategy +from smac.random_design.probability_design import ProbabilityRandomDesign +from smac.runhistory.encoder.encoder import RunHistoryEncoder +from smac.runhistory.encoder.log_encoder import RunHistoryLogEncoder +from smac.scenario import Scenario +from smac.utils.logging import get_logger +from smac.acquisition.maximizer.multi_objective_search import MOLocalAndSortedRandomSearch + +__copyright__ = "Copyright 2022, automl.org" +__license__ = "3-clause BSD" + + +logger = get_logger(__name__) + + +class MultiObjectiveFacade(AbstractFacade): + @staticmethod + def get_model( # type: ignore + scenario: Scenario, + *, + n_trees: int = 10, + ratio_features: float = 5.0 / 6.0, + min_samples_split: int = 3, + min_samples_leaf: int = 3, + max_depth: int = 20, + bootstrapping: bool = True, + pca_components: int = 4, + ) -> RandomForest: + """Returns a random forest as surrogate model. + + Parameters + ---------- + n_trees : int, defaults to 10 + The number of trees in the random forest. + ratio_features : float, defaults to 5.0 / 6.0 + The ratio of features that are considered for splitting. + min_samples_split : int, defaults to 3 + The minimum number of data points to perform a split. + min_samples_leaf : int, defaults to 3 + The minimum number of data points in a leaf. + max_depth : int, defaults to 20 + The maximum depth of a single tree. + bootstrapping : bool, defaults to True + Enables bootstrapping. + pca_components : float, defaults to 4 + Number of components to keep when using PCA to reduce dimensionality of instance features. + """ + + models = [] + for objective in scenario.objectives: + models.append( + RandomForest( + configspace=scenario.configspace, + n_trees=n_trees, + ratio_features=ratio_features, + min_samples_split=min_samples_split, + min_samples_leaf=min_samples_leaf, + max_depth=max_depth, + bootstrapping=bootstrapping, + log_y=False, + instance_features=scenario.instance_features, + pca_components=pca_components, + seed=scenario.seed, + ) + ) + + return MultiObjectiveModel(models=models, objectives=scenario.objectives) + + @staticmethod + def get_intensifier( # type: ignore + scenario: Scenario, + *, + max_config_calls: int = 2000, + max_incumbents: int = 10, + ) -> Intensifier: + """Returns ``MOIntensifier`` as intensifier. 
Uses the default configuration for ``race_against``. + + Parameters + ---------- + scenario : Scenario + max_config_calls : int, defaults to 2000 + Maximum number of configuration evaluations. Basically, how many instance-seed keys should be max evaluated + for a configuration. + max_incumbents : int, defaults to 10 + How many incumbents to keep track of in the case of multi-objective. + """ + class NewIntensifier(intermediate_decision.NewCostDominatesOldCost, + intermediate_update.ClosestIncumbentComparison, + MOIntensifier): + pass + + return NewIntensifier( + scenario=scenario, + max_config_calls=max_config_calls, + max_incumbents=max_incumbents, + ) + + @staticmethod + # TODO update acquisition function with EIHV and PIHV + def get_acquisition_function( # type: ignore + scenario: Scenario, + *, + xi: float = 0.0, + ) -> EHVI: + """Returns an Expected Improvement acquisition function. + + Parameters + ---------- + scenario : Scenario + xi : float, defaults to 0.0 + Controls the balance between exploration and exploitation of the + acquisition function. + """ + return PHVI() + + @staticmethod + def get_acquisition_maximizer( # type: ignore + scenario: Scenario, + ) -> MOLocalAndSortedRandomSearch: + """Returns local and sorted random search as acquisition maximizer.""" + optimizer = MOLocalAndSortedRandomSearch( + scenario.configspace, + seed=scenario.seed, + ) + + return optimizer + + @staticmethod + # TODO update initial design to LHD + def get_initial_design( # type: ignore + scenario: Scenario, + *, + additional_configs: list[Configuration] = [], + ) -> DefaultInitialDesign: + """Returns an initial design, which returns the default configuration. + + Parameters + ---------- + additional_configs: list[Configuration], defaults to [] + Adds additional configurations to the initial design. + """ + return DefaultInitialDesign( + scenario=scenario, + additional_configs=additional_configs, + ) + + @staticmethod + def get_random_design( # type: ignore + scenario: Scenario, + *, + probability: float = 0.5, + ) -> ProbabilityRandomDesign: + """Returns ``ProbabilityRandomDesign`` for interleaving configurations. + + Parameters + ---------- + probability : float, defaults to 0.5 + Probability that a configuration will be drawn at random. + """ + return ProbabilityRandomDesign(probability=probability, seed=scenario.seed) + + @staticmethod + def get_multi_objective_algorithm( # type: ignore + scenario: Scenario, + ) -> NoAggregationStrategy: + """Returns the mean aggregation strategy for the multi objective algorithm. + + Parameters + ---------- + scenario : Scenario + objective_weights : list[float] | None, defaults to None + Weights for averaging the objectives in a weighted manner. Must be of the same length as the number of + objectives. 
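For orientation, a hypothetical end-to-end use of the new facade, assuming the standard facade interface (`Facade(scenario, target)` followed by `optimize()`) and a target function that returns one value per objective; the objective names and the toy target are illustrative and not part of this diff:

```python
from __future__ import annotations

from ConfigSpace import Configuration, ConfigurationSpace

from smac import Scenario
from smac.facade.multi_objective_facade import MultiObjectiveFacade


def target(config: Configuration, seed: int = 0) -> dict[str, float]:
    x = config["x"]
    return {"cost1": x**2, "cost2": (x - 2) ** 2}


configspace = ConfigurationSpace({"x": (-5.0, 5.0)})
scenario = Scenario(configspace, objectives=["cost1", "cost2"], n_trials=100)

smac = MultiObjectiveFacade(scenario, target)
incumbents = smac.optimize()  # with multiple objectives this is a Pareto set, not a single config
```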
+ """ + return NoAggregationStrategy() + + @staticmethod + def get_runhistory_encoder(scenario: Scenario) -> RunHistoryEncoder: + """Returns the default runhistory encoder with native multi objective support enabled.""" + return RunHistoryEncoder(scenario, native_multi_objective=True, normalize=False) + # return RunHistoryLogEncoder(scenario, native_multi_objective=True, normalize=False) diff --git a/smac/intensifier/abstract_intensifier.py b/smac/intensifier/abstract_intensifier.py index b944867273..3e706f4ab6 100644 --- a/smac/intensifier/abstract_intensifier.py +++ b/smac/intensifier/abstract_intensifier.py @@ -1,5 +1,6 @@ from __future__ import annotations +import copy from abc import abstractmethod from typing import Any, Callable, Iterator @@ -56,13 +57,13 @@ def __init__( scenario: Scenario, n_seeds: int | None = None, max_config_calls: int | None = None, - max_incumbents: int = 10, + max_incumbents: int = 10, # TODO set in MO facade seed: int | None = None, ): self._scenario = scenario self._config_selector: ConfigSelector | None = None self._config_generator: Iterator[ConfigSelector] | None = None - self._runhistory: RunHistory | None = None + self._runhistory: RunHistory | None = RunHistory if seed is None: seed = self._scenario.seed @@ -88,11 +89,20 @@ def reset(self) -> None: self._instance_seed_keys_validation: list[InstanceSeedKey] | None = None # Incumbent variables - self._incumbents: list[Configuration] = [] + self.incumbents: list[Configuration] = [] self._incumbents_changed = 0 self._rejected_config_ids: list[int] = [] self._trajectory: list[TrajectoryItem] = [] + @property + def incumbents(self) -> list[Configuration]: + return self._incumbents + + @incumbents.setter + def incumbents(self, incumbents: list[Configuration]) -> None: + self._incumbents = incumbents + self.runhistory.incumbents = incumbents + @property def meta(self) -> dict[str, Any]: """Returns the meta data of the created object.""" @@ -353,14 +363,14 @@ def get_incumbent(self) -> Configuration | None: if self._scenario.count_objectives() > 1: raise ValueError("Cannot get a single incumbent for multi-objective optimization.") - if len(self._incumbents) == 0: + if len(self.incumbents) == 0: return None - assert len(self._incumbents) == 1 - return self._incumbents[0] + assert len(self.incumbents) == 1 + return self.incumbents[0] def get_incumbents(self, sort_by: str | None = None) -> list[Configuration]: - """Returns the incumbents (points on the pareto front) of the runhistory as copy. In case of a single-objective + """Returns the incumbents (points on the Pareto front) of the runhistory as copy. In case of a single-objective optimization, only one incumbent (if is) is returned. 
Returns @@ -374,11 +384,11 @@ def get_incumbents(self, sort_by: str | None = None) -> list[Configuration]: rh = self.runhistory if sort_by == "cost": - return list(sorted(self._incumbents, key=lambda config: rh._cost_per_config[rh.get_config_id(config)])) + return list(sorted(self.incumbents, key=lambda config: rh._cost_per_config[rh.get_config_id(config)])) elif sort_by == "num_trials": - return list(sorted(self._incumbents, key=lambda config: len(rh.get_trials(config)))) + return list(sorted(self.incumbents, key=lambda config: len(rh.get_trials(config)))) elif sort_by is None: - return list(self._incumbents) + return list(self.incumbents) else: raise ValueError(f"Unknown sort_by value: {sort_by}.") @@ -396,7 +406,7 @@ def get_incumbent_instance_seed_budget_keys(self, compare: bool = False) -> list incumbents = self.get_incumbents() if len(incumbents) > 0: - # We want to calculate the smallest set of trials that is used by all incumbents + # We want to calculate the largest set of trials that is used by all incumbents # Reason: We can not fairly compare otherwise incumbent_isb_keys = [self.get_instance_seed_budget_keys(incumbent, compare) for incumbent in incumbents] instances = list(set.intersection(*map(set, incumbent_isb_keys))) # type: ignore @@ -419,14 +429,15 @@ def get_incumbent_instance_seed_budget_key_differences(self, compare: bool = Fal return [] # Compute the actual differences - intersection_isb_keys = set.intersection(*map(set, incumbent_isb_keys)) # type: ignore - union_isb_keys = set.union(*map(set, incumbent_isb_keys)) # type: ignore - incumbent_isb_keys = list(union_isb_keys - intersection_isb_keys) # type: ignore + intersection_isb_keys = set.intersection(*map(set, incumbent_isb_keys)) + union_isb_keys = set.union(*map(set, incumbent_isb_keys)) + incumbent_isb_keys_differences = list(union_isb_keys - intersection_isb_keys) + # incumbent_isb_keys = list(set.difference(*map(set, incumbent_isb_keys))) # type: ignore - if len(incumbent_isb_keys) == 0: + if len(incumbent_isb_keys_differences) == 0: return [] - return incumbent_isb_keys # type: ignore + return incumbent_isb_keys_differences # type: ignore return [] @@ -453,62 +464,109 @@ def on_tell_end(self, smbo: smac.main.smbo.SMBO, info: TrialInfo, value: TrialVa return RunHistoryCallback(self) - def update_incumbents(self, config: Configuration) -> None: - """Updates the incumbents. This method is called everytime a trial is added to the runhistory. Since only - the affected config and the current incumbents are used, this method is very efficient. Furthermore, a - configuration is only considered incumbent if it has a better performance on all incumbent instances. - - Crucially, if there is no incumbent (at the start) then, the first configuration assumes - incumbent status. For the next configuration, we need to check if the configuration - is better on all instances that have been evaluated for the incumbent. If this is the - case, then we can replace the incumbent. Otherwise, a) we need to requeue the config to - obtain the missing instance-seed-budget combination or b) mark this configuration as - inferior ("rejected") to not consider it again. The comparison behaviour is controlled by - self.get_instance_seed_budget_keys() and self.get_incumbent_instance_seed_budget_keys(). - - Notably, this method is written to support both multi-fidelity and multi-objective - optimization. 
While the get_instance_seed_budget_keys() method and - self.get_incumbent_instance_seed_budget_keys() are used for the multi-fidelity behaviour, - calculate_pareto_front() is used as a hard coded way to support multi-objective - optimization, including the single objective as special case. calculate_pareto_front() - is called on the set of all (in case of MO) incumbents amended with the challenger - configuration, provided it has a sufficient overlap in seed-instance-budget combinations. - - Lastly, if we have a self._max_incumbents and the pareto front provides more than this - specified amount, we cut the incumbents using crowding distance. + def _check_for_intermediate_comparison(self, config: Configuration) -> bool: + """Checks if the configuration should be evaluated against the incumbent while it + did not run on all the trails the incumbents did. + + Parameters + ---------- + config: Configuration + + Returns + ------- + A boolean which decides if the current configuration should be compared against the incumbent. + """ + + return False + + def _intermediate_comparison(self, config: Configuration) -> bool: + """Compares the configuration against the incumbent + + Parameters + ---------- + config: Configuration + + Returns + ------- + A boolean which indicates if we should continue with this configuration. + """ + config_hash = get_config_hash(config) + incumbents = self.get_incumbents() + config_isb_keys = self.get_instance_seed_budget_keys(config, compare=True) + incumbent_isb_comparison_keys = self.get_incumbent_instance_seed_budget_keys(compare=True) + + + logger.debug( + f"Perform intermediate comparions of config {config_hash} with incumbents to see if it is worse" + ) + # TODO perform comparison with incumbent on current instances. + # Check if the config with these number of trials is part of the Pareto front + + # Check if the incumbents ran on all the ones of this config + if not all([key in incumbent_isb_comparison_keys for key in config_isb_keys]): + logger.debug("Config ran on other isb_keys than the incumbents. Should not happen.") + return True + + # Ensure that the config is not part of the incumbent + if config in incumbents: + return True + + # Only compare domination between one incumbent (as relaxation measure) + iid = self._rng.choice(len(incumbents)) + incumbents = [incumbents[iid], config] + # incumbents.append(config) + + # Only the trials of the challenger + all_incumbent_isb_keys = [config_isb_keys for _ in incumbents] + + new_incumbents = self._calculate_pareto_front(self.runhistory, incumbents, all_incumbent_isb_keys) + + return config in new_incumbents + + def _update_incumbent(self, config: Configuration) -> list[Configuration]: + """Updates the incumbent with the config (which can be the challenger) + + Parameters + ---------- + config: Configuration + + Returns + ------- """ rh = self.runhistory - # What happens if a config was rejected, but it appears again? Give it another try even if it - # has already been evaluated? Yes! 
+ incumbents = self.get_incumbents() - # Associated trials and id - config_isb_keys = self.get_instance_seed_budget_keys(config) - config_id = rh.get_config_id(config) + if config not in incumbents: + incumbents.append(config) + + isb_keys = self.get_incumbent_instance_seed_budget_keys(compare=True) + all_incumbent_isb_keys = [isb_keys for _ in range(len(incumbents))] + + # We compare the incumbents now and only return the ones on the Pareto front + # _calculate_pareto_front returns only non-dominated points + new_incumbents = self._calculate_pareto_front(rh, incumbents, all_incumbent_isb_keys) + return new_incumbents + + def update_incumbents(self, config: Configuration) -> None: + incumbents = self.get_incumbents() config_hash = get_config_hash(config) - # We skip updating incumbents if no instances are available + config_isb_keys = self.get_instance_seed_budget_keys(config, compare=True) + incumbent_isb_keys = self.get_incumbent_instance_seed_budget_keys(compare=True) + + #Check if config holds keys # Note: This is especially the case if trials of a config are still running # because if trials are running, the runhistory does not update the trials in the fast data structure if len(config_isb_keys) == 0: logger.debug(f"No relevant instances evaluated for config {config_hash}. Updating incumbents is skipped.") return - # Now we get the incumbents and see which trials have been used - incumbents = self.get_incumbents() - incumbent_ids = [rh.get_config_id(c) for c in incumbents] - # Find the lowest intersection of instance-seed-budget keys for all incumbents. - incumbent_isb_keys = self.get_incumbent_instance_seed_budget_keys() - - # Save for later - previous_incumbents = incumbents.copy() - previous_incumbent_ids = incumbent_ids.copy() - - # Little sanity check here for consistency if len(incumbents) > 0: assert incumbent_isb_keys is not None assert len(incumbent_isb_keys) > 0 + # Check if incumbent exists # If there are no incumbents at all, we just use the new config as new incumbent # Problem: We can add running incumbents if len(incumbents) == 0: # incumbent_isb_keys is None and len(incumbents) == 0: @@ -518,87 +576,51 @@ def update_incumbents(self, config: Configuration) -> None: # Nothing else to do return - # Comparison keys - # This one is a bit tricky: We would have problems if we compare with budgets because we might have different - # scenarios (depending on the incumbent selection specified in Successive Halving). - # 1) Any budget/highest observed budget: We want to get rid of the budgets because if we know it is calculated - # on the same instance-seed already then we are ready to go. Imagine we would check for the same budgets, - # then the configs can not be compared although the user does not care on which budgets configurations have - # been evaluated. - # 2) Highest budget: We only want to compare the configs if they are evaluated on the highest budget. - # Here we do actually care about the budgets. Please see the ``get_instance_seed_budget_keys`` method from - # Successive Halving to get more information. - # Noitce: compare=True only takes effect when subclass implemented it. -- e.g. in SH it - # will remove the budgets from the keys. - config_isb_comparison_keys = self.get_instance_seed_budget_keys(config, compare=True) - # Find the lowest intersection of instance-seed-budget keys for all incumbents. 
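Conceptually, `_update_incumbent` delegates to `_calculate_pareto_front` to keep exactly the non-dominated configurations; a simplified standalone sketch of such a filter over aggregated cost vectors (illustrative only, not SMAC's implementation):

```python
import numpy as np


def pareto_front(costs: np.ndarray) -> list[int]:
    """Return the indices of non-dominated rows, assuming minimization."""
    keep = []
    for i, cost in enumerate(costs):
        dominated = any(
            np.all(other <= cost) and np.any(other < cost)
            for j, other in enumerate(costs)
            if j != i
        )
        if not dominated:
            keep.append(i)
    return keep


costs = np.array([[0.2, 0.8], [0.5, 0.5], [0.6, 0.6], [0.9, 0.1]])
print(pareto_front(costs))  # [0, 1, 3]; row 2 is dominated by row 1
```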
- config_incumbent_isb_comparison_keys = self.get_incumbent_instance_seed_budget_keys(compare=True) - - # Now we have to check if the new config has been evaluated on the same keys as the incumbents - if not all([key in config_isb_comparison_keys for key in config_incumbent_isb_comparison_keys]): - # We can not tell if the new config is better/worse than the incumbents because it has not been - # evaluated on the necessary trials - logger.debug( - f"Could not compare config {config_hash} with incumbents because it's evaluated on " - f"different trials." - ) - - # The config has to go to a queue now as it is a challenger and a potential incumbent + #Check if config isb is subset of incumbents + # if not all([isb_key in incumbent_isb_keys for isb_key in config_isb_keys]): + # # If the config is part of the incumbents this could happen + # logger.info(f"Config {config_hash} did run on more instances than the incumbent. Cannot make a proper comparison.") + # return + + # Config did not run on all isb keys of incumbent + # Now we have to check if we should continue with this configuration + if not set(config_isb_keys) == set(incumbent_isb_keys): + # Config did not run on all trials + if self._check_for_intermediate_comparison(config): + if not self._intermediate_comparison(config): + logger.debug(f"Rejected config {config_hash} in an intermediate comparison on {len(config_isb_keys)} trials.") + self._add_rejected_config(config) return - else: - # If all instances are available and the config is incumbent and even evaluated on more trials - # then there's nothing we can do - if config in incumbents and len(config_isb_keys) > len(incumbent_isb_keys): - logger.debug( - "Config is already an incumbent but can not be compared to other incumbents because " - "the others are missing trials." - ) - return - - # Add config to incumbents so that we compare only the new config and existing incumbents - if config not in incumbents: - incumbents.append(config) - incumbent_ids.append(config_id) - # Now we get all instance-seed-budget keys for each incumbent (they might be different when using budgets) - all_incumbent_isb_keys = [] - for incumbent in incumbents: - all_incumbent_isb_keys.append(self.get_instance_seed_budget_keys(incumbent)) + # Config did run on all isb keys of incumbent + # Here we really update the incumbent by: + # 1. Removing incumbents that are now dominated by another configuration in the incumbent + # 2. 
Add in the challenger to the incumbent + rh = self.runhistory - # We compare the incumbents now and only return the ones on the pareto front - new_incumbents = calculate_pareto_front(rh, incumbents, all_incumbent_isb_keys) + previous_incumbents = copy.copy(incumbents) + previous_incumbent_ids = [rh.get_config_id(c) for c in previous_incumbents] + new_incumbents = self._update_incumbent(config) new_incumbent_ids = [rh.get_config_id(c) for c in new_incumbents] - if len(previous_incumbents) == len(new_incumbents): - if previous_incumbents == new_incumbents: - # No changes in the incumbents - self._remove_rejected_config(config_id) - return - else: - # In this case, we have to determine which config replaced which incumbent and reject it - removed_incumbent_id = list(set(previous_incumbent_ids) - set(new_incumbent_ids))[0] - removed_incumbent_hash = get_config_hash(rh.get_config(removed_incumbent_id)) - self._add_rejected_config(removed_incumbent_id) - - if removed_incumbent_id == config_id: - logger.debug( - f"Rejected config {config_hash} because it is not better than the incumbents on " - f"{len(config_isb_keys)} instances." - ) - else: - self._remove_rejected_config(config_id) - logger.info( - f"Added config {config_hash} and rejected config {removed_incumbent_hash} as incumbent because " - f"it is not better than the incumbents on {len(config_isb_keys)} instances:" - ) - print_config_changes(rh.get_config(removed_incumbent_id), config, logger=logger) + # Update trajectory + if previous_incumbents == new_incumbents: # Only happens with incumbent config + self._remove_rejected_config(config) + return + elif len(previous_incumbents) == len(new_incumbents): + # In this case, we have to determine which config replaced which incumbent and reject it + # We will remove the oldest configuration (the one with the lowest id) because + # set orders the ids ascending. + self._remove_incumbent(config=config, + previous_incumbent_ids=previous_incumbent_ids, + new_incumbent_ids=new_incumbent_ids) elif len(previous_incumbents) < len(new_incumbents): # Config becomes a new incumbent; nothing is rejected in this case - self._remove_rejected_config(config_id) + self._remove_rejected_config(config) logger.info( f"Config {config_hash} is a new incumbent. " f"Total number of incumbents: {len(new_incumbents)}." ) - else: + else: # len(previous_incumbents) > len(new_incumbents) # There might be situations that the incumbents might be removed because of updated cost information of # config for incumbent in previous_incumbents: @@ -612,20 +634,279 @@ def update_incumbents(self, config: Configuration) -> None: # Cut incumbents: We only want to keep a specific number of incumbents # We use the crowding distance for that if len(new_incumbents) > self._max_incumbents: - new_incumbents = sort_by_crowding_distance(rh, new_incumbents, all_incumbent_isb_keys) - new_incumbents = new_incumbents[: self._max_incumbents] + all_incumbent_isb_keys = [incumbent_isb_keys for i in range(len(new_incumbents))] + new_incumbents = self._cut_incumbents(new_incumbents, all_incumbent_isb_keys) + #TODO JG adjust. Other option: statistical test or HV (SMS-EMOA reduce function) + + self._update_trajectory(new_incumbents) - # or random? - # idx = self._rng.randint(0, len(new_incumbents)) - # del new_incumbents[idx] - # del new_incumbent_ids[idx] + # def update_incumbents(self, config: Configuration) -> None: + # """Updates the incumbents. This method is called everytime a trial is added to the runhistory. 
Since only + # the affected config and the current incumbents are used, this method is very efficient. Furthermore, a + # configuration is only considered incumbent if it has a better performance on all incumbent instances. + # + # Crucially, if there is no incumbent (at the start) then, the first configuration assumes + # incumbent status. For the next configuration, we need to check if the configuration + # is better on all instances that have been evaluated for the incumbent. If this is the + # case, then we can replace the incumbent. Otherwise, a) we need to requeue the config to + # obtain the missing instance-seed-budget combination or b) mark this configuration as + # inferior ("rejected") to not consider it again. The comparison behaviour is controlled by + # self.get_instance_seed_budget_keys() and self.get_incumbent_instance_seed_budget_keys(). + # + # Notably, this method is written to support both multi-fidelity and multi-objective + # optimization. While the get_instance_seed_budget_keys() method and + # self.get_incumbent_instance_seed_budget_keys() are used for the multi-fidelity behaviour, + # calculate_pareto_front() is used as a hard coded way to support multi-objective + # optimization, including the single objective as special case. calculate_pareto_front() + # is called on the set of all (in case of MO) incumbents amended with the challenger + # configuration, provided it has a sufficient overlap in seed-instance-budget combinations. + # + # Lastly, if we have a self._max_incumbents and the pareto front provides more than this + # specified amount, we cut the incumbents using crowding distance. + # """ + # rh = self.runhistory + # + # # What happens if a config was rejected, but it appears again? Give it another try even if it + # # has already been evaluated? Yes! + # + # #TODO what to do when config is part of the incumbent? + # + # # Associated trials and id + # config_isb_keys = self.get_instance_seed_budget_keys(config) + # config_id = rh.get_config_id(config) + # config_hash = get_config_hash(config) + # + # # We skip updating incumbents if no instances are available + # # Note: This is especially the case if trials of a config are still running + # # because if trials are running, the runhistory does not update the trials in the fast data structure + # if len(config_isb_keys) == 0: + # logger.debug(f"No relevant instances evaluated for config {config_hash}. Updating incumbents is skipped.") + # return + # + # # Now we get the incumbents and see which trials have been used + # incumbents = self.get_incumbents() + # incumbent_ids = [rh.get_config_id(c) for c in incumbents] + # # Find the lowest intersection of instance-seed-budget keys for all incumbents. 
+ # incumbent_isb_keys = self.get_incumbent_instance_seed_budget_keys() + # + # # Save for later + # previous_incumbents = incumbents.copy() + # previous_incumbent_ids = incumbent_ids.copy() + # + # # Little sanity check here for consistency + # if len(incumbents) > 0: + # assert incumbent_isb_keys is not None + # assert len(incumbent_isb_keys) > 0 + # + # # If there are no incumbents at all, we just use the new config as new incumbent + # # Problem: We can add running incumbents + # if len(incumbents) == 0: # incumbent_isb_keys is None and len(incumbents) == 0: + # logger.info(f"Added config {config_hash} as new incumbent because there are no incumbents yet.") + # self._update_trajectory([config]) + # + # # Nothing else to do + # return + # + # # Comparison keys + # # This one is a bit tricky: We would have problems if we compare with budgets because we might have different + # # scenarios (depending on the incumbent selection specified in Successive Halving). + # # 1) Any budget/highest observed budget: We want to get rid of the budgets because if we know it is calculated + # # on the same instance-seed already then we are ready to go. Imagine we would check for the same budgets, + # # then the configs can not be compared although the user does not care on which budgets configurations have + # # been evaluated. + # # 2) Highest budget: We only want to compare the configs if they are evaluated on the highest budget. + # # Here we do actually care about the budgets. Please see the ``get_instance_seed_budget_keys`` method from + # # Successive Halving to get more information. + # # Noitce: compare=True only takes effect when subclass implemented it. -- e.g. in SH it + # # will remove the budgets from the keys. + # config_isb_comparison_keys = self.get_instance_seed_budget_keys(config, compare=True) + # # Find the lowest intersection of instance-seed-budget keys for all incumbents. + # config_incumbent_isb_comparison_keys = self.get_incumbent_instance_seed_budget_keys(compare=True) # Intersection + # + # # Now we have to check if the new config has been evaluated on the same keys as the incumbents + # # TODO If the config is part of the incumbent then it should always be a subset of the intersection + # if not all([key in config_isb_comparison_keys for key in config_incumbent_isb_comparison_keys]): + # # We can not tell if the new config is better/worse than the incumbents because it has not been + # # evaluated on the necessary trials + # + # # TODO JG add procedure to check if intermediate comparison + # if self._check_for_intermediate_comparison(config): + # if not self._intermediate_comparison(config): + # #Reject config + # logger.debug(f"Rejected config {config_hash} in an intermediate comparison because it " + # f"is dominated by a randomly sampled config from the incumbents on " + # f"{len(config_isb_keys)} trials.") + # self._add_rejected_config(config) + # + # return + # + # else: + # #TODO + # + # logger.debug( + # f"Could not compare config {config_hash} with incumbents because it's evaluated on " + # f"different trials." + # ) + # + # # The config has to go to a queue now as it is a challenger and a potential incumbent + # return + # else: + # # If all instances are available and the config is incumbent and even evaluated on more trials + # # then there's nothing we can do + # # TODO JG: Will always be false, because the incumbent with the smallest number of trials has been ran. 
+ # # TODO JG: Hence: len(config_isb_keys) == len(incumbent_isb_keys) + # if config in incumbents and len(config_isb_keys) > len(incumbent_isb_keys): + # logger.debug( + # "Config is already an incumbent but can not be compared to other incumbents because " + # "the others are missing trials." + # ) + # return + # + # if self._final_comparison(config): + # + # + # # Add config to incumbents so that we compare only the new config and existing incumbents + # if config not in incumbents: + # incumbents.append(config) + # incumbent_ids.append(config_id) + # + # # Now we get all instance-seed-budget keys for each incumbent (they might be different when using budgets) + # all_incumbent_isb_keys = [] + # for incumbent in incumbents: + # # all_incumbent_isb_keys.append(self.get_instance_seed_budget_keys(incumbent)) + # all_incumbent_isb_keys.append(self.get_incumbent_instance_seed_budget_keys()) # !!!!! + # + # #TODO JG it is guaruanteed that the challenger has ran on the intersection of isb_keys + # # of the incumbents, however this is not the case in this part of the code. + # # Here, all the runs of each incumbent used. Maybe the intensifier ensures that the incumbents + # # have ran on the same isb keys in the first place? + # # FIXED IN LINE 580 + # + # #TODO JG get intersection for all incumbent_isb_keys and check if it breaks budget. + # + # # We compare the incumbents now and only return the ones on the Pareto front + # new_incumbents = self._calculate_pareto_front(rh, incumbents, all_incumbent_isb_keys) + # new_incumbent_ids = [rh.get_config_id(c) for c in new_incumbents] + # + # if len(previous_incumbents) == len(new_incumbents): + # if previous_incumbents == new_incumbents: + # # No changes in the incumbents + # self._remove_rejected_config(config_id) # This means that the challenger is not rejected!! + # return + # else: + # # In this case, we have to determine which config replaced which incumbent and reject it + # # We will remove the oldest configuration (the one with the lowest id) because + # # set orders the ids ascending. + # self._remove_incumbent(config=config, previous_incumbent_ids=previous_incumbent_ids, new_incumbent_ids=new_incumbent_ids) + # elif len(previous_incumbents) < len(new_incumbents): + # # Config becomes a new incumbent; nothing is rejected in this case + # self._remove_rejected_config(config_id) + # logger.info( + # f"Config {config_hash} is a new incumbent. " f"Total number of incumbents: {len(new_incumbents)}." + # ) + # else: # len(previous_incumbents) > len(new_incumbents) + # # There might be situations that the incumbents might be removed because of updated cost information of + # # config + # for incumbent in previous_incumbents: + # if incumbent not in new_incumbents: + # self._add_rejected_config(incumbent) + # logger.debug( + # f"Removed incumbent {get_config_hash(incumbent)} because of the updated costs from config " + # f"{config_hash}." + # ) + # + # # Cut incumbents: We only want to keep a specific number of incumbents + # # We use the crowding distance for that + # if len(new_incumbents) > self._max_incumbents: + # new_incumbents = self._cut_incumbents(new_incumbents, all_incumbent_isb_keys) + # #TODO JG adjust. 
Other option: statistical test or HV (SMS-EMOA reduce function) + # + # self._update_trajectory(new_incumbents) + + def _cut_incumbents(self, incumbent_ids: list[int], all_incumbent_isb_keys: list[list[InstanceSeedBudgetKey]]) -> list[int]: + new_incumbents = sort_by_crowding_distance(self.runhistory, incumbent_ids, all_incumbent_isb_keys) + new_incumbents = new_incumbents[: self._max_incumbents] + + # or random? + # idx = self._rng.randint(0, len(new_incumbents)) + # del new_incumbents[idx] + # del new_incumbent_ids[idx] + + logger.info( + f"Removed one incumbent using crowding distance because more than {self._max_incumbents} are " + "available." + ) + + return new_incumbents + + def _remove_incumbent(self, config: Configuration, previous_incumbent_ids: list[int], new_incumbent_ids: list[int]) -> None: + """Remove incumbents if population is too big + + If new and old incumbents differ. + Remove the oldest (the one with the lowest id) from the set of new and old incumbents. + If the current config is not discarded, it is added to the new incumbents. + + Parameters + ---------- + config : Configuration + Newly evaluated trial + previous_incumbent_ids : list[int] + Incumbents before + new_incumbent_ids : list[int] + Incumbents considering/maybe including config + """ + assert len(previous_incumbent_ids) == len(new_incumbent_ids) + assert previous_incumbent_ids != new_incumbent_ids + rh = self.runhistory + config_isb_keys = self.get_instance_seed_budget_keys(config) + config_id = rh.get_config_id(config) + config_hash = get_config_hash(config) + + removed_incumbent_id = list(set(previous_incumbent_ids) - set(new_incumbent_ids))[0] + removed_incumbent_hash = get_config_hash(rh.get_config(removed_incumbent_id)) + self._add_rejected_config(removed_incumbent_id) + + if removed_incumbent_id == config_id: + logger.debug( + f"Rejected config {config_hash} because it is not better than the incumbents on " + f"{len(config_isb_keys)} instances." + ) + else: + self._remove_rejected_config(config_id) logger.info( - f"Removed one incumbent using crowding distance because more than {self._max_incumbents} are " - "available." + f"Added config {config_hash} and rejected config {removed_incumbent_hash} as incumbent because " + f"it is not better than the incumbents on {len(config_isb_keys)} instances:" ) + print_config_changes(rh.get_config(removed_incumbent_id), config, logger=logger) - self._update_trajectory(new_incumbents) + def _calculate_pareto_front( + self, + runhistory: RunHistory, + configs: list[Configuration], + config_instance_seed_budget_keys: list[list[InstanceSeedBudgetKey]], + ) -> list[Configuration]: + """Compares the passed configurations and returns only the ones on the pareto front. + + Parameters + ---------- + runhistory : RunHistory + The runhistory containing the given configurations. + configs : list[Configuration] + The configurations from which the Pareto front should be computed. + config_instance_seed_budget_keys: list[list[InstanceSeedBudgetKey]] + The instance-seed budget keys for the configurations on the basis of which the Pareto front should be computed. + + Returns + ------- + pareto_front : list[Configuration] + The pareto front computed from the given configurations. 
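+
+        Examples
+        --------
+        A sketch of a typical call site (mirroring ``_update_incumbent`` in the update
+        mixins): the current incumbents plus the challenger ``config`` are compared on
+        the incumbents' shared instance-seed-budget keys::
+
+            configs = self.get_incumbents() + [config]
+            isb_keys = self.get_incumbent_instance_seed_budget_keys(compare=True)
+            front = self._calculate_pareto_front(self.runhistory, configs, [isb_keys] * len(configs))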
+ """ + return calculate_pareto_front( + runhistory=runhistory, + configs=configs, + config_instance_seed_budget_keys=config_instance_seed_budget_keys, + ) @abstractmethod def __iter__(self) -> Iterator[TrialInfo]: @@ -643,6 +924,25 @@ def set_state(self, state: dict[str, Any]) -> None: """Sets the state of the intensifier. Used to restore the state of the intensifier when continuing a run.""" pass + def get_save_data(self) -> dict: + incumbent_ids = [] + for config in self.incumbents: + try: + incumbent_ids.append(self.runhistory.get_config_id(config)) + except KeyError: + incumbent_ids.append(-1) #Should not happen, but occurs sometimes with small-budget runs + logger.warning(f"{config} does not exist in runhistory, but is part of the incumbent!") + + data = { + "incumbent_ids": incumbent_ids, + "rejected_config_ids": self._rejected_config_ids, + "incumbents_changed": self._incumbents_changed, + "trajectory": [dataclasses.asdict(item) for item in self._trajectory], + "state": self.get_state(), + } + + return data + def save(self, filename: str | Path) -> None: """Saves the current state of the intensifier. In addition to the state (retrieved by ``get_state``), this method also saves the incumbents and trajectory. @@ -653,13 +953,7 @@ def save(self, filename: str | Path) -> None: assert str(filename).endswith(".json") filename.parent.mkdir(parents=True, exist_ok=True) - data = { - "incumbent_ids": [self.runhistory.get_config_id(config) for config in self._incumbents], - "rejected_config_ids": self._rejected_config_ids, - "incumbents_changed": self._incumbents_changed, - "trajectory": [dataclasses.asdict(item) for item in self._trajectory], - "state": self.get_state(), - } + data = self.get_save_data() with open(filename, "w") as fp: json.dump(data, fp, indent=2) @@ -683,7 +977,7 @@ def load(self, filename: str | Path) -> None: if self._runhistory is not None: self.runhistory = self._runhistory - self._incumbents = [self.runhistory.get_config(config_id) for config_id in data["incumbent_ids"]] + self.incumbents = [self.runhistory.get_config(config_id) for config_id in data["incumbent_ids"]] self._incumbents_changed = data["incumbents_changed"] self._rejected_config_ids = data["rejected_config_ids"] self._trajectory = [TrajectoryItem(**item) for item in data["trajectory"]] @@ -694,7 +988,7 @@ def _update_trajectory(self, configs: list[Configuration]) -> None: config_ids = [rh.get_config_id(c) for c in configs] costs = [rh.average_cost(c, normalize=False) for c in configs] - self._incumbents = configs + self.incumbents = configs self._incumbents_changed += 1 self._trajectory.append( TrajectoryItem( diff --git a/smac/intensifier/intensifier.py b/smac/intensifier/intensifier.py index 345013f878..44626cd2e3 100644 --- a/smac/intensifier/intensifier.py +++ b/smac/intensifier/intensifier.py @@ -1,5 +1,6 @@ from __future__ import annotations +import time from typing import Any, Iterator from ConfigSpace import Configuration @@ -51,10 +52,15 @@ def __init__( max_config_calls: int = 3, max_incumbents: int = 10, retries: int = 16, + min_config_calls: int = 1, seed: int | None = None, ): super().__init__(scenario=scenario, max_config_calls=max_config_calls, max_incumbents=max_incumbents, seed=seed) self._retries = retries + self._min_config_calls = min_config_calls + + if max_config_calls < min_config_calls: + raise ValueError("min_config_calls must be smaller or equal than max_config_calls") def reset(self) -> None: """Resets the internal variables of the intensifier including the queue.""" @@ -100,6 
+106,7 @@ def __iter__(self) -> Iterator[TrialInfo]: queue. - If all incumbents are evaluated on the same trials, a new trial is added to one of the incumbents. - Only challengers which are not rejected/running/incumbent are intensified by N*2. + - If the intensifier cannot find any new trials for n _retries, exit Returns ------- @@ -111,6 +118,12 @@ def __iter__(self) -> Iterator[TrialInfo]: rh = self.runhistory assert self._max_config_calls is not None + is_keys = self.get_instance_seed_keys_of_interest() + if len(is_keys) < self._min_config_calls: + logger.debug(f"There are less instance, seed pairs of interest than the requested minimum trails per " + f"configuration. Changing min_config_calls from {self._min_config_calls} to {len(is_keys)}") + self._min_config_calls = len(is_keys) + # What if there are already trials in the runhistory? Should we queue them up? # Because they are part of the runhistory, they might be selected as incumbents. However, they are not # intensified because they are not part of the queue. We could add them here to incorporate them in the @@ -122,7 +135,7 @@ def __iter__(self) -> Iterator[TrialInfo]: if len(self._queue) == 0: for config in rh.get_configs(): hash = get_config_hash(config) - self._queue.append((config, 1)) + self._queue.append((config, self._min_config_calls)) logger.info(f"Added config {hash} from runhistory to the intensifier queue.") fails = -1 @@ -142,7 +155,7 @@ def __iter__(self) -> Iterator[TrialInfo]: # Also, incorporate ``get_incumbent_instance_seed_budget_keys`` here because challengers are only allowed to # sample from the incumbent's instances incumbents = self.get_incumbents(sort_by="num_trials") - incumbent_isb_keys = self.get_incumbent_instance_seed_budget_keys() + incumbent_isb_keys = self.get_incumbent_instance_seed_budget_keys() # Intersection # Check if configs in queue are still running all_configs_running = True @@ -151,7 +164,7 @@ def __iter__(self) -> Iterator[TrialInfo]: all_configs_running = False break - if len(self._queue) == 0 or all_configs_running: + if len(self._queue) == 0 or all_configs_running: # Incumbents if len(self._queue) == 0: logger.debug("Queue is empty:") else: @@ -208,9 +221,11 @@ def __iter__(self) -> Iterator[TrialInfo]: f"{self._max_config_calls} from incumbent {incumbent_hash}..." ) yield trials[0] + logger.debug(f"--- Finished yielding for config {incumbent_hash}.") # We break here because we only want to intensify one more trial of one incumbent + # TODO intensify until the incumbents are all of equal size (N+1 of biggest incumbent) break else: # assert len(incumbent_isb_keys) == self._max_config_calls @@ -225,7 +240,7 @@ def __iter__(self) -> Iterator[TrialInfo]: try: config = next(self.config_generator) config_hash = get_config_hash(config) - self._queue.append((config, 1)) + self._queue.append((config, self._min_config_calls)) logger.debug(f"--- Added a new config {config_hash} to the queue.") # If we added a new config, then we did something in this iteration @@ -255,46 +270,90 @@ def __iter__(self) -> Iterator[TrialInfo]: self._queue.remove((config, N)) continue + logger.debug(f"--- Config {config_hash} origin ({config.origin})") + # And then we yield as many trials as we specified N # However, only the same instances as the incumbents are used isk_keys: list[InstanceSeedBudgetKey] | None = None if len(incumbent_isb_keys) > 0: isk_keys = incumbent_isb_keys - # TODO: What to do if there are no incumbent instances? 
(Use-case: call multiple asks) trials = self._get_next_trials(config, N=N, from_keys=isk_keys) - logger.debug(f"--- Yielding {len(trials)} trials to evaluate config {config_hash}...") - for trial in trials: - fails = -1 - yield trial - - logger.debug(f"--- Finished yielding for config {config_hash}.") - - # Now we have to remove the config - self._queue.remove((config, N)) - logger.debug(f"--- Removed config {config_hash} with N={N} from queue.") - - # Finally, we add the same config to the queue with a higher N - # If the config was rejected by the runhistory, then it's been removed in the next iteration - if N < self._max_config_calls: - new_pair = (config, N * 2) - if new_pair not in self._queue: - logger.debug( - f"--- Doubled trials of config {config_hash} to N={N*2} and added it to the queue " - "again." - ) - self._queue.append((config, N * 2)) - - # Also reset fails here + if len(trials) == 0: + # We remove the config and do not add it back to the queue. + self._queue.remove((config, N)) + logger.debug(f"--- No trails to evaluate for config {config_hash}. " + f"Removed config {config_hash} with N={N} from queue.") + else: + logger.debug(f"--- Yielding {len(trials)} trials to evaluate config {config_hash}...") + for trial in trials: + # We need to check if the configuration has been rejected! + if config in self.get_rejected_configs(): + logger.debug(f"--- {config_hash} was rejected so we do not run any more trials") + break fails = -1 - else: - logger.debug(f"--- Config {config_hash} with N={N*2} is already in the queue.") + yield trial + + logger.debug(f"--- Finished yielding for config {config_hash}.") + + # Now we have to remove the config + self._queue.remove((config, N)) + logger.debug(f"--- Removed config {config_hash} with N={N} from queue.") + + + # Finally, we add the same config to the queue with a higher N + # If the config was rejected by the runhistory, then it's been removed in the next iteration + if N < self._max_config_calls and config not in self.get_rejected_configs(): + new_pair = (config, N * 2) + if new_pair not in self._queue: + logger.debug( + f"--- Doubled trials of config {config_hash} to N={N*2} and added it to the queue " + "again." + ) + self._queue.append((config, N * 2)) + + # Also reset fails here + fails = -1 + else: + logger.debug(f"--- Config {config_hash} with N={N*2} is already in the queue.") # If we are at this point, it really is important to break because otherwise, we would intensify # all configs in the queue in one iteration break + def _check_for_intermediate_comparison(self, config: Configuration) -> bool: + """ + + Parameters + ---------- + config: Configuration + + Returns + ------- + A boolean which decides if the current configuration should be compared against the incumbent. 
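+
+        Notes
+        -----
+        Sketch of the default rule implemented below: no comparison is made while the
+        configuration has fewer than 4 evaluated trials; after that, a comparison is due
+        exactly when the number of evaluated instance-seed-budget keys equals the target
+        ``N`` the configuration is currently queued with.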
+ """ + config_isb_keys = self.get_instance_seed_budget_keys(config) + config_id = self.runhistory.get_config_id(config) + config_hash = get_config_hash(config) + + # Do not compare very early in the process + if len(config_isb_keys) < 4: + return False + + # Find N in _queue + N = None + for c, cn in self._queue: + if config == c: + N = cn + break + + if N is None: + logger.debug(f"This should not happen, but config {config_hash} is not in the queue.") + return False + + return len(config_isb_keys) == N + def _get_next_trials( self, config: Configuration, diff --git a/smac/intensifier/mixins/__init__.py b/smac/intensifier/mixins/__init__.py new file mode 100644 index 0000000000..7b024699ae --- /dev/null +++ b/smac/intensifier/mixins/__init__.py @@ -0,0 +1,3 @@ +""" +Mixin are used to overwrite single functions in the intensifier classes +""" \ No newline at end of file diff --git a/smac/intensifier/mixins/intermediate_decision.py b/smac/intensifier/mixins/intermediate_decision.py new file mode 100644 index 0000000000..7e90e37137 --- /dev/null +++ b/smac/intensifier/mixins/intermediate_decision.py @@ -0,0 +1,164 @@ +from __future__ import annotations + +import copy +import itertools +from abc import abstractmethod +from typing import Any, Callable, Iterator +from scipy.stats import binom + +import dataclasses +import json +from collections import defaultdict +from pathlib import Path + +import numpy as np +from ConfigSpace import Configuration + +import smac +from smac.callback import Callback +from smac.constants import MAXINT +from smac.main.config_selector import ConfigSelector +from smac.runhistory import TrialInfo +from smac.runhistory.dataclasses import ( + InstanceSeedBudgetKey, + InstanceSeedKey, + TrajectoryItem, + TrialValue, +) +from smac.runhistory.runhistory import RunHistory +from smac.scenario import Scenario +from smac.utils.configspace import get_config_hash, print_config_changes +from smac.utils.logging import get_logger +from smac.utils.pareto_front import calculate_pareto_front, sort_by_crowding_distance, _get_costs + +__copyright__ = "Copyright 2022, automl.org" +__license__ = "3-clause BSD" + +logger = get_logger(__name__) + + +def _dominates(a, b) -> bool: + # Checks if a dominates b + a = np.atleast_1d(a) + b = np.atleast_1d(b) + return np.count_nonzero(a <= b) >= len(a) and np.count_nonzero(a < b) >= 1 + +class NewCostDominatesOldCost(): + + def _check_for_intermediate_comparison(self, config: Configuration) -> bool: + """ + + Parameters + ---------- + config: Configuration + + Returns + ------- + A boolean which decides if the current configuration should be compared against the incumbent. 
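+
+        Notes
+        -----
+        Sketch of the rule implemented below: the average cost of ``config`` over its
+        evaluated keys is cached per configuration; a comparison is triggered on the very
+        first observation and afterwards whenever the new average cost Pareto-dominates
+        the previously cached one (see ``_dominates`` above).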
+ """ + config_isb_keys = self.get_instance_seed_budget_keys(config) + + if not hasattr(self, "_old_config_cost"): + self._old_config_cost = {} # TODO remove configuration when done + + new_cost = self.runhistory.average_cost(config, config_isb_keys) + if config not in self._old_config_cost: + self._old_config_cost[config] = new_cost + return True + + old_cost = self._old_config_cost[config] + if _dominates(new_cost, old_cost): + self._old_config_cost[config] = new_cost + return True + return False + +class NewCostDominatesOldCostSkipFirst(): + + def _check_for_intermediate_comparison(self, config: Configuration) -> bool: + """ Do the first comparison with the incumbent when the configuration dominates the cost after finishing its first trial + + Parameters + ---------- + config: Configuration + + Returns + ------- + A boolean which decides if the current configuration should be compared against the incumbent. + """ + config_isb_keys = self.get_instance_seed_budget_keys(config) + + if not hasattr(self, "_old_config_cost"): + self._old_config_cost = {} # TODO remove configuration when done + + new_cost = self.runhistory.average_cost(config, config_isb_keys) + if config not in self._old_config_cost: + self._old_config_cost[config] = new_cost + return False + + old_cost = self._old_config_cost[config] + if _dominates(new_cost, old_cost): + self._old_config_cost[config] = new_cost + return True + return False + +class DoublingNComparison(): + + def _check_for_intermediate_comparison(self, config: Configuration) -> bool: + """ + + Parameters + ---------- + config: Configuration + + Returns + ------- + A boolean which decides if the current configuration should be compared against the incumbent. + """ + config_isb_keys = self.get_instance_seed_budget_keys(config) + config_id = self.runhistory.get_config_id(config) + config_hash = get_config_hash(config) + + # max_trigger_number = int(np.ceil(np.log2(self._max_config_calls))) + # trigger_points = [(2**n) - 1 for n in range(1, max_trigger_number + 1)] # 1, 3, 7, 15, ... + # logger.debug(f"{trigger_points=}") + # logger.debug(f"{len(config_isb_keys)=}") + # return len(config_isb_keys) in trigger_points + + nkeys = len(config_isb_keys) + return (nkeys+1) & nkeys == 0 # checks if nkeys+1 is a power of 2 (complies with the sequence (2**n)-1) + + +class Always(): + + def _check_for_intermediate_comparison(self, config: Configuration) -> bool: + return True + + +class Never(): + + def _check_for_intermediate_comparison(self, config: Configuration) -> bool: + return False + + +class DoublingNComparisonFour(): + + def _check_for_intermediate_comparison(self, config: Configuration) -> bool: + """ + + Parameters + ---------- + config: Configuration + + Returns + ------- + A boolean which decides if the current configuration should be compared against the incumbent. + """ + config_isb_keys = self.get_instance_seed_budget_keys(config) + config_id = self.runhistory.get_config_id(config) + config_hash = get_config_hash(config) + + max_trigger_number = int(np.ceil(np.log2(self._max_config_calls))) + trigger_points = [(2 ** n) - 1 for n in range(2, max_trigger_number + 1)] # 1, 3, 7, 15, ... 
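+        # For example, with self._max_config_calls = 16 this gives trigger_points = [3, 7, 15]:
+        # unlike DoublingNComparison above, the first point (1 trial) is skipped, and a
+        # comparison is due whenever the number of evaluated keys is one less than a power
+        # of two -- the same sequence the bit trick (nkeys + 1) & nkeys == 0 detects.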
+ logger.debug(f"{trigger_points=}") + logger.debug(f"{len(config_isb_keys)=}") + return len(config_isb_keys) in trigger_points diff --git a/smac/intensifier/mixins/intermediate_update.py b/smac/intensifier/mixins/intermediate_update.py new file mode 100644 index 0000000000..413f83c696 --- /dev/null +++ b/smac/intensifier/mixins/intermediate_update.py @@ -0,0 +1,589 @@ +from __future__ import annotations + +import copy +import itertools +from abc import abstractmethod +from typing import Any, Callable, Iterator +from scipy.stats import binom + +import dataclasses +import json +from collections import defaultdict +from pathlib import Path + +import numpy as np +from ConfigSpace import Configuration + +import smac +from smac.callback import Callback +from smac.constants import MAXINT +from smac.main.config_selector import ConfigSelector +from smac.runhistory import TrialInfo +from smac.runhistory.dataclasses import ( + InstanceSeedBudgetKey, + InstanceSeedKey, + TrajectoryItem, + TrialValue, +) +from smac.runhistory.runhistory import RunHistory +from smac.scenario import Scenario +from smac.utils.configspace import get_config_hash, print_config_changes +from smac.utils.logging import get_logger +from smac.utils.pareto_front import calculate_pareto_front, sort_by_crowding_distance, _get_costs + +__copyright__ = "Copyright 2022, automl.org" +__license__ = "3-clause BSD" + +logger = get_logger(__name__) + + +class DebugComparison(object): + + def _register_comparison(self, **kwargs): + logger.debug(f"Made intermediate comparison with {kwargs['name']} comparison ") + if not hasattr(self, "_intermediate_comparisons_log"): + self._intermediate_comparisons_log = [] + self._intermediate_comparisons_log.append(kwargs) + + def _get_costs_comp(self, config: Configuration) -> dict: + incumbents = self.get_incumbents() + if config not in incumbents: + incumbents.append(config) + config_isb_keys = self.get_instance_seed_budget_keys(config, compare=True) + all_incumbent_isb_keys = [config_isb_keys for _ in incumbents] + costs = _get_costs(self.runhistory, incumbents, all_incumbent_isb_keys) + + return {conf: cost for conf, cost in zip(incumbents, costs)} + + +class FullIncumbentComparison(DebugComparison): + + def _intermediate_comparison(self, config: Configuration) -> bool: + """Compares the configuration against the incumbent + + Parameters + ---------- + config: Configuration + + Returns + ------- + A boolean which indicates if we should continue with this configuration. + """ + config_hash = get_config_hash(config) + incumbents = self.get_incumbents() + config_isb_keys = self.get_instance_seed_budget_keys(config, compare=True) + incumbent_isb_comparison_keys = self.get_incumbent_instance_seed_budget_keys( + compare=True) + + logger.debug( + f"Perform intermediate comparions of config {config_hash} with incumbents to see if it is worse" + ) + # TODO perform comparison with incumbent on current instances. + # Check if the config with these number of trials is part of the Pareto front + + # Check if the incumbents ran on all the ones of this config + if not all([key in incumbent_isb_comparison_keys for key in config_isb_keys]): + logger.debug( + "Config ran on other isb_keys than the incumbents. 
Should not happen.") + return True + + # Ensure that the config is not part of the incumbent + if config in incumbents: + return True + + # Only compare domination between one incumbent (as relaxation measure) + if config not in incumbents: + incumbents.append(config) + + # Only the trials of the challenger + all_incumbent_isb_keys = [config_isb_keys for _ in incumbents] + + new_incumbents = self._calculate_pareto_front(self.runhistory, incumbents, + all_incumbent_isb_keys) + + verdict = config in new_incumbents + self._register_comparison(config=config, + incumbent=self.get_incumbents(), + isb_keys=len(config_isb_keys), + costs=self._get_costs_comp(config), + prediction=verdict, + name="FullInc") + + return config in new_incumbents + + +class SingleIncumbentComparison(DebugComparison): + def _intermediate_comparison(self, config: Configuration) -> bool: + """Compares the configuration against the incumbent + + Parameters + ---------- + config: Configuration + + Returns + ------- + A boolean which indicates if we should continue with this configuration. + """ + config_hash = get_config_hash(config) + incumbents = self.get_incumbents() + config_isb_keys = self.get_instance_seed_budget_keys(config, compare=True) + incumbent_isb_comparison_keys = self.get_incumbent_instance_seed_budget_keys( + compare=True) + + logger.debug( + f"Perform intermediate comparions of config {config_hash} with incumbents to see if it is worse" + ) + + # Check if the incumbents ran on all the ones of this config + if not all([key in incumbent_isb_comparison_keys for key in config_isb_keys]): + logger.debug( + "Config ran on other isb_keys than the incumbents. Should not happen.") + return True + + # Ensure that the config is not part of the incumbent + if config in incumbents: + return True + + # Only compare domination between one incumbent (as relaxation measure) + iid = self._rng.choice(len(incumbents)) + incumbents = [incumbents[iid], config] + + # Only the trials of the challenger + all_incumbent_isb_keys = [config_isb_keys for _ in incumbents] + + new_incumbents = self._calculate_pareto_front(self.runhistory, + incumbents, + all_incumbent_isb_keys) + + verdict = config in new_incumbents + self._register_comparison(config=config, + incumbent=self.get_incumbents(), + isb_keys=len(config_isb_keys), + costs=self._get_costs_comp(config), + prediction=verdict, + name="SingleInc") + + return config in new_incumbents + + +class ClosestIncumbentComparison(DebugComparison): + def _intermediate_comparison(self, config: Configuration) -> bool: + """Compares the configuration against the incumbent + + Parameters + ---------- + config: Configuration + + Returns + ------- + A boolean which indicates if we should continue with this configuration. + """ + config_hash = get_config_hash(config) + incumbents = self.get_incumbents() + config_isb_keys = self.get_instance_seed_budget_keys(config, compare=True) + incumbent_isb_comparison_keys = self.get_incumbent_instance_seed_budget_keys( + compare=True) + + logger.debug( + f"Perform intermediate comparisons of config {config_hash} with incumbents to see if it is worse" + ) + + # Ensure that the config is not part of the incumbent + if config in incumbents: + return True + + # Check if the incumbents ran on all the ones of this config + if not all([key in incumbent_isb_comparison_keys for key in config_isb_keys]): + logger.debug( + "Config ran on other isb_keys than the incumbents. 
Should not happen.") + return True + + # Only compare domination between one incumbent (as relaxation measure) + #iid = self._rng.choice(len(incumbents)) + #TODO Normalize to determine closests? + inc_costs = _get_costs(self.runhistory, incumbents, [config_isb_keys for _ in incumbents], normalize=True) + conf_cost = _get_costs(self.runhistory, [config], [config_isb_keys], normalize=True)[0] + distances = [np.linalg.norm(inc_cost - conf_cost) for inc_cost in inc_costs] + iid = np.argmin(distances) + incumbents = [incumbents[iid], config] + + # Only the trials of the challenger + all_incumbent_isb_keys = [config_isb_keys for _ in incumbents] + + new_incumbents = self._calculate_pareto_front(self.runhistory, + incumbents, + all_incumbent_isb_keys) + + verdict = config in new_incumbents + self._register_comparison(config=config, + incumbent=self.get_incumbents(), + isb_keys=len(config_isb_keys), + costs=self._get_costs_comp(config), + prediction=verdict, + name="ClosestInc") + + return config in new_incumbents + + +class RandomComparison(DebugComparison): + def _intermediate_comparison(self, config: Configuration) -> bool: + """Compares the configuration against the incumbent + + Parameters + ---------- + config: Configuration + + Returns + ------- + A boolean which indicates if we should continue with this configuration. + """ + config_hash = get_config_hash(config) + incumbents = self.get_incumbents() + config_isb_keys = self.get_instance_seed_budget_keys(config, compare=True) + incumbent_isb_comparison_keys = self.get_incumbent_instance_seed_budget_keys( + compare=True) + + # Check if the incumbents ran on all the ones of this config + if not all([key in incumbent_isb_comparison_keys for key in config_isb_keys]): + logger.debug( + "Config ran on other isb_keys than the incumbents. Should not happen.") + return True + + # Ensure that the config is not part of the incumbent + if config in incumbents: + return True + + config_isb_keys = self.get_instance_seed_budget_keys(config, compare=True) + verdict = self._rng.random() >= 0.5 + self._register_comparison(config=config, + incumbent=self.get_incumbents(), + isb_keys=len(config_isb_keys), + costs=self._get_costs_comp(config), + prediction=verdict, + name="Random") + return verdict + + +class NoComparison(DebugComparison): + def _intermediate_comparison(self, config: Configuration) -> bool: + """Compares the configuration against the incumbent + + Parameters + ---------- + config: Configuration + + Returns + ------- + A boolean which indicates if we should continue with this configuration. + """ + config_hash = get_config_hash(config) + incumbents = self.get_incumbents() + config_isb_keys = self.get_instance_seed_budget_keys(config, compare=True) + incumbent_isb_comparison_keys = self.get_incumbent_instance_seed_budget_keys( + compare=True) + + # Check if the incumbents ran on all the ones of this config + if not all([key in incumbent_isb_comparison_keys for key in config_isb_keys]): + logger.debug( + "Config ran on other isb_keys than the incumbents. 
Should not happen.") + return True + + # Ensure that the config is not part of the incumbent + if config in incumbents: + return True + + config_isb_keys = self.get_instance_seed_budget_keys(config, compare=True) + verdict = True + self._register_comparison(config=config, + incumbent=self.get_incumbents(), + isb_keys=len(config_isb_keys), + costs=self._get_costs_comp(config), + prediction=verdict, + name="NoComp") + return verdict + + +class BootstrapComparison(DebugComparison): + + def _intermediate_comparison(self, config: Configuration) -> bool: + """Compares the configuration by generating bootstraps + + Parameters + ---------- + config: Configuration + + Returns + ------- + A boolean which indicates if we should continue with this configuration. + """ + config_hash = get_config_hash(config) + incumbents = self.get_incumbents() + config_isb_keys = self.get_instance_seed_budget_keys(config, compare=True) + incumbent_isb_comparison_keys = self.get_incumbent_instance_seed_budget_keys( + compare=True) + + # Check if the incumbents ran on all the ones of this config + if not all([key in incumbent_isb_comparison_keys for key in config_isb_keys]): + logger.debug( + "Config ran on other isb_keys than the incumbents. Should not happen.") + return True + + # Ensure that the config is not part of the incumbent + if config in incumbents: + return True + + if config not in incumbents: + incumbents.append(config) + + n_samples = 1000 + if len(config_isb_keys) < 7: # When there are only a limited number of trials available we run all combinations + samples = list(itertools.combinations_with_replacement(list(range(len(config_isb_keys))), r=len(config_isb_keys))) + n_samples = len(samples) + else: + samples = np.random.choice(len(config_isb_keys), + (n_samples, len(config_isb_keys)), + replace=True) + + verdicts = np.zeros(n_samples, dtype=bool) + + + for sid, sample in enumerate(samples): + sample_isb_keys = [config_isb_keys[i] for i in sample] + all_incumbent_isb_keys = [sample_isb_keys]*len(incumbents) + new_incumbents = self._calculate_pareto_front(self.runhistory, + incumbents, + all_incumbent_isb_keys) + + verdicts[sid] = config in new_incumbents + + verdict = np.count_nonzero(verdicts) >= 0.5 * n_samples # The config is in more than 50% of the times non-dominated + #P = np.count_nonzero(verdicts)/n_samples + #print(f"P = {np.count_nonzero(verdicts)}/{n_samples}={P:.2f}") + self._register_comparison(config=config, + incumbent=self.get_incumbents(), + isb_keys=len(config_isb_keys), + costs=self._get_costs_comp(config), + prediction=verdict, + name="Bootstrap", + probability=np.count_nonzero(verdicts)/n_samples, + n_samples=n_samples) + return verdict + + +class BootstrapSingleComparison(DebugComparison): + + def _intermediate_comparison(self, config: Configuration) -> bool: + """Compares the configuration by generating bootstraps + + Parameters + ---------- + config: Configuration + + Returns + ------- + A boolean which indicates if we should continue with this configuration. + """ + + config_hash = get_config_hash(config) + incumbents = self.get_incumbents() + config_isb_keys = self.get_instance_seed_budget_keys(config, compare=True) + incumbent_isb_comparison_keys = self.get_incumbent_instance_seed_budget_keys( + compare=True) + + # Check if the incumbents ran on all the ones of this config + if not all([key in incumbent_isb_comparison_keys for key in config_isb_keys]): + logger.debug( + "Config ran on other isb_keys than the incumbents. 
Should not happen.") + return True + + # Ensure that the config is not part of the incumbent + if config in incumbents: + return True + + iid = self._rng.choice(len(incumbents)) + incumbents = [incumbents[iid], config] + + n_samples = 1000 + if len(config_isb_keys) < 7: # When there are only a limited number of trials available we run all combinations + samples = list(itertools.combinations_with_replacement(list(range(len(config_isb_keys))), r=len(config_isb_keys))) + n_samples = len(samples) + else: + samples = np.random.choice(len(config_isb_keys), + (n_samples, len(config_isb_keys)), + replace=True) + + verdicts = np.zeros(n_samples, dtype=bool) + + + for sid, sample in enumerate(samples): + sample_isb_keys = [config_isb_keys[i] for i in sample] + all_incumbent_isb_keys = [sample_isb_keys]*len(incumbents) + new_incumbents = self._calculate_pareto_front(self.runhistory, + incumbents, + all_incumbent_isb_keys) + + verdicts[sid] = config in new_incumbents + + verdict = np.count_nonzero(verdicts) >= 0.5 * n_samples # The config is in more than 50% of the times non-dominated + #P = np.count_nonzero(verdicts)/n_samples + #print(f"P = {np.count_nonzero(verdicts)}/{n_samples}={P:.2f}") + self._register_comparison(config=config, + incumbent=self.get_incumbents(), + isb_keys=len(config_isb_keys), + costs=self._get_costs_comp(config), + prediction=verdict, + name="BootstrapSingle", + probability=np.count_nonzero(verdicts)/n_samples, + n_samples=n_samples) + return verdict + + +class BootstrapClosestComparison(DebugComparison): + + def _intermediate_comparison(self, config: Configuration) -> bool: + """Compares the configuration by generating bootstraps + + Parameters + ---------- + config: Configuration + + Returns + ------- + A boolean which indicates if we should continue with this configuration. + """ + + config_hash = get_config_hash(config) + incumbents = self.get_incumbents() + config_isb_keys = self.get_instance_seed_budget_keys(config, compare=True) + incumbent_isb_comparison_keys = self.get_incumbent_instance_seed_budget_keys( + compare=True) + + # Check if the incumbents ran on all the ones of this config + if not all([key in incumbent_isb_comparison_keys for key in config_isb_keys]): + logger.debug( + "Config ran on other isb_keys than the incumbents. 
Should not happen.") + return True + + # Ensure that the config is not part of the incumbent + if config in incumbents: + return True + + inc_costs = _get_costs(self.runhistory, incumbents, [config_isb_keys for _ in incumbents], normalize=True) + conf_cost = _get_costs(self.runhistory, [config], [config_isb_keys], normalize=True)[0] + distances = [np.linalg.norm(inc_cost - conf_cost) for inc_cost in inc_costs] + iid = np.argmin(distances) + incumbents = [incumbents[iid], config] + + n_samples = 1000 + if len(config_isb_keys) < 7: # When there are only a limited number of trials available we run all combinations + samples = list(itertools.combinations_with_replacement(list(range(len(config_isb_keys))), r=len(config_isb_keys))) + n_samples = len(samples) + else: + samples = np.random.choice(len(config_isb_keys), + (n_samples, len(config_isb_keys)), + replace=True) + + verdicts = np.zeros(n_samples, dtype=bool) + + + for sid, sample in enumerate(samples): + sample_isb_keys = [config_isb_keys[i] for i in sample] + all_incumbent_isb_keys = [sample_isb_keys]*len(incumbents) + new_incumbents = self._calculate_pareto_front(self.runhistory, + incumbents, + all_incumbent_isb_keys) + + verdicts[sid] = config in new_incumbents + + verdict = np.count_nonzero(verdicts) >= 0.5 * n_samples # The config is in more than 50% of the times non-dominated + #P = np.count_nonzero(verdicts)/n_samples + #print(f"P = {np.count_nonzero(verdicts)}/{n_samples}={P:.2f}") + self._register_comparison(config=config, + incumbent=self.get_incumbents(), + isb_keys=len(config_isb_keys), + costs=self._get_costs_comp(config), + prediction=verdict, + name="BootstrapClosest", + probability=np.count_nonzero(verdicts)/n_samples, + n_samples=n_samples) + return verdict + + +class SRaceComparison(DebugComparison): + def _intermediate_comparison(self, config: Configuration) -> bool: + """Compares the configuration by generating bootstraps + + Parameters + ---------- + config: Configuration + + Returns + ------- + A boolean which indicates if we should continue with this configuration. + """ + + def get_alpha(delta, n_instances): + steps = 0 + n = 1 + inst = 0 + while inst < n_instances: + steps += 1 + inst += n + n *= 2 + + return (1 - delta) / (n_instances) * (steps - 1) + + def dominates(a, b): + # Checks if a dominates b + a = np.array(a) + b = np.array(b) + return 1 if np.count_nonzero(a <= b) >= len(a) and np.count_nonzero(a < b) >= 1 else 0 + + config_hash = get_config_hash(config) + incumbents = self.get_incumbents() + config_isb_keys = self.get_instance_seed_budget_keys(config, compare=True) + incumbent_isb_comparison_keys = self.get_incumbent_instance_seed_budget_keys( + compare=True) + + # Check if the incumbents ran on all the ones of this config + if not all([key in incumbent_isb_comparison_keys for key in config_isb_keys]): + logger.debug( + "Config ran on other isb_keys than the incumbents. 
Should not happen.") + return True + + # Ensure that the config is not part of the incumbent + if config in incumbents: + return True + + p_values = [] + chall_perf = self.runhistory._cost(config, config_isb_keys) + for incumbent in incumbents: + inc_perf = self.runhistory._cost(incumbent, config_isb_keys) + n_ij = sum([dominates(*x) for x in zip(chall_perf, inc_perf)]) # Number of times the incumbent candidate dominates the challenger + n_ji = sum([dominates(*x) for x in zip(inc_perf, chall_perf)]) # Number of times the challenger dominates the incumbent candidate + p_value = 1 - binom.cdf(n_ij - 1, n_ij + n_ji, .5) + p_values.append(p_value) + + pvalues_order = np.argsort(p_values) + + # Holm-Bonferroni + reject = np.zeros(len(p_values), dtype=bool) # Do not reject any test by default + alpha = get_alpha(0.05, len(config_isb_keys)) + for i, index in enumerate(pvalues_order): + corrected_alpha = alpha / (len(p_values) - i) # Holm-Bonferroni + if pvalues_order[index] < corrected_alpha: + # Reject H0 -> winner > candidate + reject[index] = True + else: + break + + verdict = np.count_nonzero(reject) != 0 + #P = np.count_nonzero(verdicts)/n_samples + #print(f"P = {np.count_nonzero(verdicts)}/{n_samples}={P:.2f}") + self._register_comparison(config=config, + incumbent=self.get_incumbents(), + isb_keys=len(config_isb_keys), + costs={conf: cost for conf, cost in zip(incumbents, costs)}, + prediction=verdict, + name="S-Race") + return verdict \ No newline at end of file diff --git a/smac/intensifier/mixins/update_incumbent.py b/smac/intensifier/mixins/update_incumbent.py new file mode 100644 index 0000000000..8ddb5d8e24 --- /dev/null +++ b/smac/intensifier/mixins/update_incumbent.py @@ -0,0 +1,132 @@ +from __future__ import annotations + +import copy +import itertools +from abc import abstractmethod +from typing import Any, Callable, Iterator +from scipy.stats import binom + +import dataclasses +import json +from collections import defaultdict +from pathlib import Path + +import numpy as np +from ConfigSpace import Configuration + +import smac +from smac.callback import Callback +from smac.constants import MAXINT +from smac.main.config_selector import ConfigSelector +from smac.runhistory import TrialInfo +from smac.runhistory.dataclasses import ( + InstanceSeedBudgetKey, + InstanceSeedKey, + TrajectoryItem, + TrialValue, +) +from smac.runhistory.runhistory import RunHistory +from smac.scenario import Scenario +from smac.utils.configspace import get_config_hash, print_config_changes +from smac.utils.logging import get_logger +from smac.utils.pareto_front import calculate_pareto_front, sort_by_crowding_distance, _get_costs + +__copyright__ = "Copyright 2022, automl.org" +__license__ = "3-clause BSD" + +logger = get_logger(__name__) + + +class DebugUpdate(object): + def _register_incumbent_update(self, **kwargs): + if not hasattr(self, "_update_incumbent_log"): + self._update_incumbent_log = [] + self._update_incumbent_log.append(kwargs) + +class NonDominatedUpdate(DebugUpdate): + + def _update_incumbent(self, config: Configuration) -> list[Configuration]: + """Updates the incumbent with the config (which can be the challenger) + + Parameters + ---------- + config: Configuration + + Returns + ------- + """ + rh = self.runhistory + + incumbents = self.get_incumbents() + + if config not in incumbents: + incumbents.append(config) + + isb_keys = self.get_incumbent_instance_seed_budget_keys(compare=True) + all_incumbent_isb_keys = [isb_keys for _ in range(len(incumbents))] + + # We compare the incumbents 
now and only return the ones on the Pareto front + # _calculate_pareto_front returns only non-dominated points + new_incumbents = self._calculate_pareto_front(rh, incumbents, + all_incumbent_isb_keys) + + self._register_incumbent_update(config=config, + incumbent=self.get_incumbents(), + isb_keys=isb_keys, + new_incumbents=new_incumbents, + name="NonDominated",) + + return new_incumbents +class BootstrapUpdate(DebugUpdate): + + def _update_incumbent(self, config: Configuration) -> list[Configuration]: + """Updates the incumbent with the config (which can be the challenger) + + Parameters + ---------- + config: Configuration + + Returns + ------- + """ + rh = self.runhistory + + incumbents = self.get_incumbents() + + if config not in incumbents: + incumbents.append(config) + + isb_keys = self.get_incumbent_instance_seed_budget_keys(compare=True) + + n_samples = 1000 + if len(isb_keys) < 7: # When there are only a limited number of trials available we run all combinations + samples = list(itertools.combinations_with_replacement(list(range(len(isb_keys))), r=len(isb_keys))) + n_samples = len(samples) + else: + samples = np.random.choice(len(isb_keys), (n_samples, len(isb_keys)), replace=True) + + verdicts = np.zeros((n_samples, len(incumbents)), dtype=bool) + + for sid, sample in enumerate(samples): + sample_isb_keys = [isb_keys[i] for i in sample] + all_incumbent_isb_keys = [sample_isb_keys] * len(incumbents) + new_incumbents = self._calculate_pareto_front(self.runhistory, + incumbents, + all_incumbent_isb_keys) + + verdicts[sid, :] = [incumbents[i] in new_incumbents for i in range(len(incumbents))] + + probabilities = np.count_nonzero(verdicts, axis=0) / n_samples + + new_incumbent_ids = np.argwhere(probabilities >= 0.5).flatten() # Incumbent needs to be non-dominated at least 50% of the time + new_incumbents = [incumbents[i] for i in new_incumbent_ids] + + self._register_incumbent_update(config=config, + incumbent=self.get_incumbents(), + isb_keys=isb_keys, + new_incumbents=new_incumbents, + name="Bootstrap", + probabilities=probabilities, + n_samples=n_samples,) + + return new_incumbents \ No newline at end of file diff --git a/smac/intensifier/multi_objective_intensifier.py b/smac/intensifier/multi_objective_intensifier.py new file mode 100644 index 0000000000..8ecf7ea538 --- /dev/null +++ b/smac/intensifier/multi_objective_intensifier.py @@ -0,0 +1,91 @@ +# TODO does this work for multi-fidelity? 
+# Yes, then pass a pareto front calculation function to the abstract intensifier instead of subclassing it + +from __future__ import annotations + +from abc import abstractmethod +from typing import Any, Callable, Iterator + +import dataclasses +import json +from collections import defaultdict +from pathlib import Path + +import numpy as np +from ConfigSpace import Configuration + +import smac +from smac.callback import Callback +from smac.constants import MAXINT +from smac.main.config_selector import ConfigSelector +from smac.runhistory import TrialInfo +from smac.runhistory.dataclasses import ( + InstanceSeedBudgetKey, + InstanceSeedKey, + TrajectoryItem, + TrialValue, +) +from smac.runhistory.runhistory import RunHistory +from smac.scenario import Scenario +from smac.utils.configspace import get_config_hash, print_config_changes +from smac.utils.logging import get_logger +from smac.utils.pareto_front import calculate_pareto_front, sort_by_crowding_distance +from smac.intensifier.abstract_intensifier import AbstractIntensifier +from smac.intensifier.hyperband import Hyperband +from smac.intensifier.successive_halving import SuccessiveHalving +from smac.intensifier.intensifier import Intensifier + + +__copyright__ = "Copyright 2022, automl.org" +__license__ = "3-clause BSD" + +logger = get_logger(__name__) + +# TODO add minimum population size? + +class MOIntensifierMixin(object): + def _calculate_pareto_front( + self, + runhistory: RunHistory, + configs: list[Configuration], + config_instance_seed_budget_keys: list[list[InstanceSeedBudgetKey]], + ) -> list[Configuration]: + return calculate_pareto_front( + runhistory=runhistory, + configs=configs, + config_instance_seed_budget_keys=config_instance_seed_budget_keys, + ) + + # def _remove_incumbent(self, config: Configuration, previous_incumbent_ids: list[int], new_incumbent_ids: list[int]) -> None: + # # TODO adjust + # raise NotImplementedError + + def _cut_incumbents(self, incumbent_ids: list[int], all_incumbent_isb_keys: list[list[InstanceSeedBudgetKey]]) -> list[int]: + #TODO JG sort by hypervolume + new_incumbents = sort_by_crowding_distance(self.runhistory, incumbent_ids, all_incumbent_isb_keys) + new_incumbents = new_incumbents[: self._max_incumbents] + + logger.info( + f"Removed one incumbent using their reduction in hypervolume because more than {self._max_incumbents} are " + "available." + ) + + return new_incumbents + + def get_instance_seed_budget_keys( + self, config: Configuration, compare: bool = False + ) -> list[InstanceSeedBudgetKey]: + """Returns the instance-seed-budget keys for a given configuration. This method is *used for + updating the incumbents* and might differ for different intensifiers. For example, if incumbents should only + be compared on the highest observed budgets. 
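+
+        For this multi-objective mixin the keys are restricted to the highest observed
+        budget per instance-seed pair (``highest_observed_budget_only=True`` below), i.e.
+        incumbent comparisons are based on the highest budget a configuration has reached.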
+ """ + return self.runhistory.get_instance_seed_budget_keys(config, highest_observed_budget_only=True) + +class MOIntensifier(MOIntensifierMixin, Intensifier): + pass + +class MOSuccessiveHalving(MOIntensifierMixin, SuccessiveHalving): + pass + +class MOHyperband(MOIntensifierMixin, Hyperband): + pass \ No newline at end of file diff --git a/smac/main/config_selector.py b/smac/main/config_selector.py index 7b2053eee2..5f46834a78 100644 --- a/smac/main/config_selector.py +++ b/smac/main/config_selector.py @@ -1,5 +1,6 @@ from __future__ import annotations +import time from typing import Any, Iterator import copy @@ -35,8 +36,13 @@ class ConfigSelector: Parameters ---------- - retrain_after : int, defaults to 8 + retrain_after : int | None, defaults to 8 How many configurations should be returned before the surrogate model is retrained. + retrain_wallclock_ratio: float | None, default to None + How much time of the total elapsed wallclock time should be spend on retraining the surrogate model + and the acquisition function look. Example ratio of 0.1 would result in that only 10% of the wallclock time is spend on retraining. + min_configurations: int, defaults to 2 + The minimum number of configurations that need to yield before retraining can occur. Should be lower or equal to retrain_after. retries : int, defaults to 8 How often to retry receiving a new configuration before giving up. min_trials: int, defaults to 1 @@ -50,7 +56,9 @@ def __init__( self, scenario: Scenario, *, - retrain_after: int = 8, + retrain_after: int | None = 8, + retrain_wallclock_ratio: float | None = None, + min_configurations: int = 2, retries: int = 16, min_trials: int = 1, ) -> None: @@ -70,6 +78,8 @@ def __init__( # And other variables self._retrain_after = retrain_after + self._retrain_wallclock_ratio = retrain_wallclock_ratio + self._min_configurations = min_configurations self._previous_entries = -1 self._predict_x_best = True self._min_trials = min_trials @@ -78,10 +88,22 @@ def __init__( # How often to retry receiving a new configuration # (counter increases if the received config was already returned before) self._retries = retries + self._counter = 0 + + self._wallclock_start_time: float = time.time() + self._acquisition_training_times: list[float] = [] # Processed configurations should be stored here; this is important to not return the same configuration twice self._processed_configs: list[Configuration] = [] + #Check if there is at least one retrain condition + if self._retrain_after is None and self._retrain_wallclock_ratio is None: + raise ValueError("No retrain condition specified!") + + if self._retrain_after is not None: + if self._retrain_after < self._min_configurations: + raise ValueError("retrain_after should be higher or equal to min_configurations") + def _set_components( self, initial_design: AbstractInitialDesign, @@ -105,6 +127,12 @@ def _set_components( if len(self._initial_design_configs) == 0: raise RuntimeError("SMAC needs initial configurations to work.") + if hasattr(self._acquisition_function, "runhistory"): + self._acquisition_function.runhistory = runhistory + + if hasattr(self._acquisition_function, "runhistory_encoder"): + self._acquisition_function.runhistory_encoder = runhistory_encoder + @property def meta(self) -> dict[str, Any]: """Returns the meta data of the created object.""" @@ -126,7 +154,7 @@ def __iter__(self) -> Iterator[Configuration]: Note ---- When SMAC continues a run, processed configurations from the runhistory are ignored. 
For example, if the - intitial design configurations already have been processed, they are ignored here. After the run is + initial design configurations already have been processed, they are ignored here. After the run is continued, however, the surrogate model is trained based on the runhistory in all cases. Returns @@ -186,6 +214,7 @@ def __iter__(self) -> Iterator[Configuration]: continue # Check if X/Y differs from the last run, otherwise use cached results + start_time = time.time() if self._previous_entries != Y.shape[0]: self._model.train(X, Y) @@ -204,6 +233,7 @@ def __iter__(self) -> Iterator[Configuration]: incumbent_array=x_best_array, num_data=len(self._get_evaluated_configs()), X=X_configurations, + incumbents=self._runhistory.incumbents, ) # We want to cache how many entries we used because if we have the same number of entries @@ -217,22 +247,25 @@ def __iter__(self) -> Iterator[Configuration]: random_design=self._random_design, ) - counter = 0 + if self._retrain_wallclock_ratio is not None: + len(challengers) # TODO hacky: Forces actual computation of the acquisition function maximizer + + self._acquisition_training_times.append(time.time() - start_time) + + failed_counter = 0 for config in challengers: if config not in self._processed_configs: - counter += 1 + self._counter += 1 self._processed_configs.append(config) self._call_callbacks_on_end(config) yield config - retrain = counter == self._retrain_after + retrain = self._check_for_retrain() self._call_callbacks_on_start() # We break to enforce a new iteration of the while loop (i.e. we retrain the surrogate model) if retrain: - logger.debug( - f"Yielded {counter} configurations. Start new iteration and retrain surrogate model." - ) + self._counter = 0 break else: failed_counter += 1 @@ -242,6 +275,35 @@ def __iter__(self) -> Iterator[Configuration]: logger.warning(f"Could not return a new configuration after {self._retries} retries." "") return + def _check_for_retrain(self) -> bool: + if self._retrain_after is not None: + if self._counter >= self._retrain_after: + logger.debug( + f"Yielded {self._counter} configurations. Start new iteration and retrain surrogate model." + ) + return True + + if self._retrain_wallclock_ratio is not None: + if self._counter < self._min_configurations: + return False + + # Total elapsed wallcock time + elapsed_time = time.time() - self._wallclock_start_time + + # Total time spend on getting configurations with the surrogate model + acquisition_training_time = sum(self._acquisition_training_times) + + # Retrain when more time has been spend + if acquisition_training_time / elapsed_time < self._retrain_wallclock_ratio: + logger.debug( + f"Less than {self._retrain_wallclock_ratio:.2%} ({acquisition_training_time / elapsed_time:.2f}) " + f"of the elapsed wallclock time ({elapsed_time:.2f}s) has been spend on finding new configurations " + f"with the surrogate model. Start new iteration and retrain surrogate model." 
diff --git a/smac/model/multi_objective_model.py b/smac/model/multi_objective_model.py
index 23a765a8ff..e2baca3822 100644
--- a/smac/model/multi_objective_model.py
+++ b/smac/model/multi_objective_model.py
@@ -53,6 +53,8 @@ def __init__(
             seed=seed,
         )

+        self._n_features = self._models[0]._n_features  # TODO (JG): make this more elegant
+
     @property
     def models(self) -> list[AbstractModel]:
         """The internally used surrogate models."""
diff --git a/smac/multi_objective/aggregation_strategy.py b/smac/multi_objective/aggregation_strategy.py
index bac0a00972..e139d52650 100644
--- a/smac/multi_objective/aggregation_strategy.py
+++ b/smac/multi_objective/aggregation_strategy.py
@@ -42,3 +42,25 @@ def meta(self) -> dict[str, Any]:

     def __call__(self, values: list[float]) -> float:  # noqa: D102
         return float(np.average(values, axis=0, weights=self._objective_weights))
+
+
+class NoAggregationStrategy(AbstractMultiObjectiveAlgorithm):
+    """
+    A strategy that does not aggregate the multi-objective losses into a single loss but passes them through unchanged.
+    """
+
+    def __call__(self, values: list[float]) -> list[float]:
+        """
+        Return the multi-objective loss unchanged instead of transforming it into a single loss.
+
+        Parameters
+        ----------
+        values : list[float]
+            Normalized cost values.
+
+        Returns
+        -------
+        costs : list[float]
+            The unchanged cost values.
+        """
+        return values
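As a quick illustration of the pass-through behaviour (a sketch; it assumes `NoAggregationStrategy` needs no constructor arguments and is meant to be paired with the native multi-objective encoding below):

```python
from smac.multi_objective.aggregation_strategy import NoAggregationStrategy

strategy = NoAggregationStrategy()

# The normalized per-objective costs are returned unchanged, so a natively
# multi-objective surrogate model can be trained on one column per objective.
assert strategy([0.2, 0.7]) == [0.2, 0.7]
```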
diff --git a/smac/runhistory/encoder/abstract_encoder.py b/smac/runhistory/encoder/abstract_encoder.py
index 42454771ed..3f4f3a2fa7 100644
--- a/smac/runhistory/encoder/abstract_encoder.py
+++ b/smac/runhistory/encoder/abstract_encoder.py
@@ -32,6 +32,7 @@ class AbstractRunHistoryEncoder:
     scale_percentage : int, defaults to 5
         Scaled y-transformation use a percentile to estimate distance to optimum.
         Only used in some sub-classes.
     seed : int | None, defaults to none
+    native_multi_objective : bool, defaults to False
+        Whether the encoded cost matrix keeps one column per objective instead of
+        aggregating the objectives into a single column.
+    normalize : bool, defaults to True
+        Whether the costs are normalized with the objective bounds from the runhistory before encoding.

     Raises
     ------
@@ -50,6 +51,8 @@ def __init__(
         lower_budget_states: list[StatusType] = [],
         scale_percentage: int = 5,
         seed: int | None = None,
+        native_multi_objective: bool = False,
+        normalize: bool = True,
     ) -> None:
         if considered_states is None:
             raise TypeError("No success states are given.")
@@ -86,6 +89,9 @@ def __init__(
         self._multi_objective_algorithm: AbstractMultiObjectiveAlgorithm | None = None
         self._runhistory: RunHistory | None = None

+        self._native_multi_objective = native_multi_objective
+        self._normalize = normalize
+
     @property
     def meta(self) -> dict[str, Any]:
         """
@@ -299,3 +305,4 @@ def transform_response_values(
         transformed_values : np.ndarray
         """
         raise NotImplementedError
+
diff --git a/smac/runhistory/encoder/encoder.py b/smac/runhistory/encoder/encoder.py
index 25672a92ff..05ac3fe203 100644
--- a/smac/runhistory/encoder/encoder.py
+++ b/smac/runhistory/encoder/encoder.py
@@ -18,6 +18,7 @@

 class RunHistoryEncoder(AbstractRunHistoryEncoder):
+
     def _build_matrix(
         self,
         trials: Mapping[TrialKey, TrialValue],
@@ -29,8 +30,11 @@ def _build_matrix(
         X = np.ones([n_rows, n_cols + self._n_features]) * np.nan

         # For now we keep it as 1
-        # TODO: Extend for native multi-objective
-        y = np.ones([n_rows, 1])
+        # TODO: Extend with checks for native multi-objective (return size of multi_objective_algorithm)
+        if self._native_multi_objective:
+            y = np.ones([n_rows, self._n_objectives])
+        else:
+            y = np.ones([n_rows, 1])

         # Then populate matrix
         for row, (key, run) in enumerate(trials.items()):
@@ -51,7 +55,7 @@ def _build_matrix(

                 # Let's normalize y here
                 # We use the objective_bounds calculated by the runhistory
-                y_ = normalize_costs(run.cost, self.runhistory.objective_bounds)
+                y_ = normalize_costs(run.cost, self.runhistory.objective_bounds) if self._normalize else run.cost
                 y_agg = self._multi_objective_algorithm(y_)
                 y[row] = y_agg
             else:
diff --git a/smac/runhistory/runhistory.py b/smac/runhistory/runhistory.py
index aaf88889c8..c7b7adf19e 100644
--- a/smac/runhistory/runhistory.py
+++ b/smac/runhistory/runhistory.py
@@ -137,6 +137,10 @@ def reset(self) -> None:
         self._n_objectives: int = -1
         self._objective_bounds: list[tuple[float, float]] = []

+        # Store incumbents. Gets updated whenever the incumbents in the
+        # intensifier are updated
+        self._incumbents: list[Configuration] = []
+
     def __contains__(self, k: object) -> bool:
         """Dictionary semantics for `k in runhistory`."""
         return k in self._data
@@ -157,6 +161,14 @@ def __eq__(self, other: Any) -> bool:
         """Enables to check equality of runhistory if the run is continued."""
         return self._data == other._data

+    @property
+    def incumbents(self) -> list[Configuration]:
+        """The current incumbents, as set by the intensifier."""
+        return self._incumbents
+
+    @incumbents.setter
+    def incumbents(self, incumbents: list[Configuration]) -> None:
+        self._incumbents = incumbents
+
     def empty(self) -> bool:
         """Check whether the RunHistory is empty.

@@ -451,6 +463,7 @@ def average_cost(
         config: Configuration,
         instance_seed_budget_keys: list[InstanceSeedBudgetKey] | None = None,
         normalize: bool = False,
+        run_multi_objective_algorithm: bool = False,
     ) -> float | list[float]:
         """Return the average cost of a configuration. This is the mean of costs of all instance-
         seed pairs.
@@ -481,10 +494,11 @@ def average_cost(
         averaged_costs = np.mean(costs, axis=0).tolist()

         if normalize:
-            assert self.multi_objective_algorithm is not None
-            normalized_costs = normalize_costs(averaged_costs, self._objective_bounds)
+            averaged_costs = normalize_costs(averaged_costs, self._objective_bounds)

-            return self.multi_objective_algorithm(normalized_costs)
+        if run_multi_objective_algorithm:
+            assert self.multi_objective_algorithm is not None
+            return self.multi_objective_algorithm(averaged_costs)
         else:
             return averaged_costs

@@ -594,6 +608,9 @@ def get_config(self, config_id: int) -> Configuration:

     def get_config_id(self, config: Configuration) -> int:
         """Returns the configuration id from a configuration."""
+        if config not in self._config_ids:
+            logger.warning("Requested id of unknown configuration!")
+            return -1
         return self._config_ids[config]

     def has_config(self, config: Configuration) -> bool:
diff --git a/smac/runner/aclib_runner.py b/smac/runner/aclib_runner.py
new file mode 100644
index 0000000000..d84aedff3c
--- /dev/null
+++ b/smac/runner/aclib_runner.py
@@ -0,0 +1,84 @@
+from __future__ import annotations
+
+__copyright__ = "Copyright 2022, automl.org"
+__license__ = "3-clause BSD"
+
+import re
+from abc import ABC, abstractmethod
+from typing import Any, Iterator
+
+import time
+import traceback
+from subprocess import Popen, PIPE
+
+import numpy as np
+from ConfigSpace import Configuration
+
+from smac.runhistory import StatusType, TrialInfo, TrialValue
+from smac.scenario import Scenario
+from smac.utils.logging import get_logger
+from smac.runner.target_function_script_runner import TargetFunctionScriptRunner
+
+logger = get_logger(__name__)
+
+class ACLibRunner(TargetFunctionScriptRunner):
+    def __init__(self,
+                 target_function: str,
+                 scenario: Scenario,
+                 required_arguments: list[str] = [],
+                 target_function_arguments: dict[str, str] | None = None,
+                 ):
+
+        self._target_function_arguments = target_function_arguments
+
+        super().__init__(target_function, scenario, required_arguments)
+
+    def __call__(self, algorithm_kwargs: dict[str, Any]) -> tuple[str, str]:
+        # kwargs has "instance", "seed" and "budget" --> translate those
+
+        cmd = self._target_function.split(" ")
+        if self._target_function_arguments is not None:
+            for k, v in self._target_function_arguments.items():
+                cmd += [f"--{k}={v}"]
+
+        if self._scenario.trial_walltime_limit is not None:
+            cmd += [f"--cutoff={self._scenario.trial_walltime_limit}"]
+
+        config = ["--config"]
+
+        for k, v in algorithm_kwargs.items():
+            v = str(v)
+            k = str(k)
+
+            # Let's remove some spaces
+            v = v.replace(" ", "")
+
+            if k in ["instance", "seed"]:
+                cmd += [f"--{k}={v}"]
+            elif k == "instance_features":
+                continue
+            else:
+                config += [k, v]
+
+        cmd += config
+
+        logger.debug(f"Calling: {' '.join(cmd)}")
+        p = Popen(cmd, shell=False, stdout=PIPE, stderr=PIPE, universal_newlines=True)
+        output, error = p.communicate()
+
+        logger.debug("Stdout: %s" % output)
+        logger.debug("Stderr: %s" % error)
+
+        result_begin = "Result for SMAC3v2: "
+        outputline = ""
+        for line in output.split("\n"):
+            line = line.strip()
+            if re.match(result_begin, line):
+                outputline = line[len(result_begin):]
+
+        logger.debug(f"Found result in output: {outputline}")
+
+        # Parse output of the form key=value;key2=value2;...;cost=value1,value2;...
+
+        return outputline, error
+
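For orientation, this is roughly the contract a target script would have to fulfil for `ACLibRunner` (a hypothetical wrapper; the exact keys in the result line depend on what the script runner parses, e.g. status, cost, runtime):

```python
# my_wrapper.py (hypothetical): invoked by ACLibRunner roughly as
#   <target_function> [--extra=... from target_function_arguments] [--cutoff=<trial_walltime_limit>]
#       --instance=<instance> --seed=<seed> --config <param1> <value1> <param2> <value2> ...
# ACLibRunner scans stdout for a line starting with "Result for SMAC3v2: ".
import sys

if __name__ == "__main__":
    # ... parse sys.argv, run the target algorithm on the given instance/seed/configuration ...
    cost = 0.42  # illustrative result
    print(f"Result for SMAC3v2: status=SUCCESS;cost={cost};runtime=3.10")
```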
diff --git a/smac/runner/dask_runner.py b/smac/runner/dask_runner.py
index b9aade4015..ad1f543710 100644
--- a/smac/runner/dask_runner.py
+++ b/smac/runner/dask_runner.py
@@ -91,7 +91,7 @@ def __init__(
             )

         if self._scenario.output_directory is not None:
-            self._scheduler_file = self._scenario.output_directory / ".dask_scheduler_file"
+            self._scheduler_file = Path(self._scenario.output_directory).joinpath(".dask_scheduler_file")
             self._client.write_scheduler_file(scheduler_file=str(self._scheduler_file))
         else:
             # We just use their set up
diff --git a/smac/runner/target_function_script_runner.py b/smac/runner/target_function_script_runner.py
index 17feffc983..cf3f49fce8 100644
--- a/smac/runner/target_function_script_runner.py
+++ b/smac/runner/target_function_script_runner.py
@@ -40,7 +40,7 @@ class TargetFunctionScriptRunner(AbstractSerialRunner):
     Parameters
     ----------
-    target_function : Callable
+    target_function : str
         The target function.
     scenario : Scenario
     required_arguments : list[str]
@@ -183,7 +183,7 @@ def run(
         if "additional_info" in outputs:
             additional_info["additional_info"] = outputs["additional_info"]

-        if status != StatusType.SUCCESS:
+        if status not in [StatusType.SUCCESS, StatusType.TIMEOUT]:
             additional_info["error"] = error

         if cost != self._crash_cost:
@@ -199,7 +199,7 @@ def __call__(
         algorithm_kwargs: dict[str, Any],
     ) -> tuple[str, str]:
         """Calls the algorithm, which is processed in the ``run`` method."""
-        cmd = [self._target_function]
+        cmd = self._target_function.split(" ")
         for k, v in algorithm_kwargs.items():
             v = str(v)
             k = str(k)
diff --git a/smac/utils/multi_objective.py b/smac/utils/multi_objective.py
index f959fe7836..8be783ed08 100644
--- a/smac/utils/multi_objective.py
+++ b/smac/utils/multi_objective.py
@@ -30,6 +30,10 @@ def normalize_costs(
     costs = []
     for v, b in zip(values, bounds):
         assert not isinstance(v, list)
+
+        # Clip the value to the bounds region
+        v = min(max(v, b[0]), b[1])
+
         p = v - b[0]
         q = b[1] - b[0]
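A quick illustration of the new clipping behaviour (illustrative numbers; the bounds are the per-objective bounds as tracked by the runhistory):

```python
from smac.utils.multi_objective import normalize_costs

bounds = [(0.0, 10.0), (0.0, 1.0)]

# The second value lies below its lower bound and is clipped to 0.0 before scaling.
print(normalize_costs([5.0, -2.0], bounds))  # [0.5, 0.0]
```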
+ """Calculate pareto front based on non-dominance Parameters ---------- @@ -105,7 +108,7 @@ def sort_by_crowding_distance( sorted_list : list[Configuration] Configurations sorted by crowding distance. """ - F = _get_costs(runhistory, configs, config_instance_seed_budget_keys) + F = _get_costs(runhistory, configs, config_instance_seed_budget_keys, normalize=True) infinity = 1e14 n_points = F.shape[0] @@ -153,3 +156,62 @@ def sort_by_crowding_distance( config_with_crowding = sorted(config_with_crowding, key=lambda x: x[1], reverse=True) return [c for c, _ in config_with_crowding] + +def sort_by_hypervolume_contribution( + runhistory: RunHistory, + configs: list[Configuration], + config_instance_seed_budget_keys: list[list[InstanceSeedBudgetKey]], +) -> list[Configuration]: + """ Sorts the passed configurations by their hypervolume contribution. + + Parameters + ---------- + runhistory : RunHistory + The runhistory containing the given configurations. + configs : list[Configuration] + The configurations which should be sorted. + config_instance_seed_budget_keys: list[list[InstanceSeedBudgetKey]] + The instance-seed budget keys for the configurations which should be sorted. + + Returns + ------- + sorted_list : list[Configuration] + Configurations sorted by hypervolume contribution. + """ + + # Get the average costs per configuration + + # Normalize the costs per objective + + # Compute a reference point (with the local points or all observed history) + + # Apply reduce procedure + + # Sort based on HV contribution + + raise NotImplementedError + +def calculate_hypervolume( + runhistory: RunHistory, + configs: list[Configuration], + config_instance_seed_budget_keys: list[list[InstanceSeedBudgetKey]], + reference_point: list[float] | None = None, +) -> float: + if reference_point is None: + reference_point = calculate_reference_point(runhistory) + + + raise NotImplementedError + +def calculate_reference_point( + runhistory: RunHistory, + configs: list[Configuration] | None = None, + config_instance_seed_budget_keys: list[list[InstanceSeedBudgetKey]] | None = None, +) -> list[float]: + if configs is None: + # Compute over the complete runhistory + costs = [trail.cost for trial in runhistory.values()] + return np.max(np.array(costs), axis=1) + else: + assert len(configs) == len(config_instance_seed_budget_keys) + raise NotImplementedError \ No newline at end of file diff --git a/tests/test_utils/test_pareto_front.py b/tests/test_utils/test_pareto_front.py index c1bbf14030..19ee35b298 100644 --- a/tests/test_utils/test_pareto_front.py +++ b/tests/test_utils/test_pareto_front.py @@ -30,15 +30,15 @@ def test_crowding_distance(configspace_small): configs = configspace_small.sample_configuration(20) config_instance_seed_budget_keys = [[isb_key]] * 20 - # Add points on pareto + # Add points on Pareto rh.add(configs[0], cost=[5, 5], instance=isb_key.instance, budget=isb_key.budget, seed=isb_key.seed) rh.add(configs[1], cost=[4, 6], instance=isb_key.instance, budget=isb_key.budget, seed=isb_key.seed) - # Add points not on pareto + # Add points not on Pareto rh.add(configs[2], cost=[5, 6], instance=isb_key.instance, budget=isb_key.budget, seed=isb_key.seed) rh.add(configs[3], cost=[5, 6], instance=isb_key.instance, budget=isb_key.budget, seed=isb_key.seed) - # Calculate pareto front + # Calculate Pareto front incumbents = calculate_pareto_front(rh, configs[:4], config_instance_seed_budget_keys[:4]) sorted_configs = sort_by_crowding_distance(rh, incumbents, config_instance_seed_budget_keys[: 
diff --git a/tests/test_utils/test_pareto_front.py b/tests/test_utils/test_pareto_front.py
index c1bbf14030..19ee35b298 100644
--- a/tests/test_utils/test_pareto_front.py
+++ b/tests/test_utils/test_pareto_front.py
@@ -30,15 +30,15 @@ def test_crowding_distance(configspace_small):
     configs = configspace_small.sample_configuration(20)
     config_instance_seed_budget_keys = [[isb_key]] * 20

-    # Add points on pareto
+    # Add points on Pareto
     rh.add(configs[0], cost=[5, 5], instance=isb_key.instance, budget=isb_key.budget, seed=isb_key.seed)
     rh.add(configs[1], cost=[4, 6], instance=isb_key.instance, budget=isb_key.budget, seed=isb_key.seed)

-    # Add points not on pareto
+    # Add points not on Pareto
     rh.add(configs[2], cost=[5, 6], instance=isb_key.instance, budget=isb_key.budget, seed=isb_key.seed)
     rh.add(configs[3], cost=[5, 6], instance=isb_key.instance, budget=isb_key.budget, seed=isb_key.seed)

-    # Calculate pareto front
+    # Calculate Pareto front
     incumbents = calculate_pareto_front(rh, configs[:4], config_instance_seed_budget_keys[:4])
     sorted_configs = sort_by_crowding_distance(rh, incumbents, config_instance_seed_budget_keys[: len(incumbents)])

     # Nothing should happen if we only have two points on the pareto front