From a75db5243ddd74227b9151c0060192ef285849c3 Mon Sep 17 00:00:00 2001
From: Anselm Hahn
Date: Mon, 5 Dec 2022 10:03:45 +0100
Subject: [PATCH] feat: :boom: Add _pseudo_ genetic search

The proposed implementation of a genetic algorithm for hyperparameter
optimization. Even though genetic optimization might be costly for CNNs,
its applications in numerical analysis or Design of Experiments (DoE)
still make it interesting.

Fixes: #47

Further Reading:

1. [Vishwakarma G, et al. Towards Autonomous Machine Learning in Chemistry via Evolutionary Algorithms. **ChemRxiv.**](https://chemrxiv.org/engage/api-gateway/chemrxiv/assets/orp/resource/item/60c7445a337d6c2849e26d98/original/towards-autonomous-machine-learning-in-chemistry-via-evolutionary-algorithms.pdf)
2. [Rosanna Nichols et al. 2019 _Quantum Sci. Technol._ **4** 045012](https://iopscience.iop.org/article/10.1088/2058-9565/ab4d89/meta?casa_token=db7uZRqRMEAAAAAA:fRO9qB25dAkeoskS6MMyzpZw2jSiMkpsN4zA_k6lheWUXaSUU8fPS-JPMoNFcIl9tka4OPCG5AtDtiM)
---
 keras_tuner/__init__.py            |   1 +
 keras_tuner/tuners/__init__.py     |   1 +
 keras_tuner/tuners/genetic.py      | 533 ++++++++++++++++++++++++
 keras_tuner/tuners/genetic_test.py | 646 +++++++++++++++++++++++++++++
 setup.cfg                          |   6 +
 5 files changed, 1187 insertions(+)
 create mode 100644 keras_tuner/tuners/genetic.py
 create mode 100644 keras_tuner/tuners/genetic_test.py

diff --git a/keras_tuner/__init__.py b/keras_tuner/__init__.py
index f3ebf4953..2e7469e77 100755
--- a/keras_tuner/__init__.py
+++ b/keras_tuner/__init__.py
@@ -25,6 +25,7 @@
 from keras_tuner.engine.oracle import Oracle
 from keras_tuner.engine.tuner import Tuner
 from keras_tuner.tuners import BayesianOptimization
+from keras_tuner.tuners import GeneticOptimization
 from keras_tuner.tuners import GridSearch
 from keras_tuner.tuners import Hyperband
 from keras_tuner.tuners import RandomSearch
diff --git a/keras_tuner/tuners/__init__.py b/keras_tuner/tuners/__init__.py
index 3fdb7491f..6ead3e5ce 100644
--- a/keras_tuner/tuners/__init__.py
+++ 
b/keras_tuner/tuners/__init__.py
@@ -14,6 +14,7 @@
 
 
 from keras_tuner.tuners.bayesian import BayesianOptimization
+from keras_tuner.tuners.genetic import GeneticOptimization
 from keras_tuner.tuners.gridsearch import GridSearch
 from keras_tuner.tuners.hyperband import Hyperband
 from keras_tuner.tuners.randomsearch import RandomSearch
diff --git a/keras_tuner/tuners/genetic.py b/keras_tuner/tuners/genetic.py
new file mode 100644
index 000000000..ee4edd393
--- /dev/null
+++ b/keras_tuner/tuners/genetic.py
@@ -0,0 +1,612 @@
+# Copyright 2019 The KerasTuner Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import random
+from copy import deepcopy
+
+import numpy as np
+
+from keras_tuner.engine import hyperparameters as hp_module
+from keras_tuner.engine import oracle as oracle_module
+from keras_tuner.engine import trial as trial_module
+from keras_tuner.engine import tuner as tuner_module
+
+
+class GeneticEvolutionaryProcess(object):
+    """Genetic operators working on `HyperParameters` chromosomes.
+
+    A chromosome is one `HyperParameters` instance and a population is a
+    list of chromosomes. The class provides mutation, crossover and parent
+    selection.
+
+    Args:
+        mutation_factor: Float, the probability that `_mutate` resamples a
+            chromosome.
+        crossover_factor: Float, the probability that `_crossover` swaps
+            values between two parents.
+        seed: Optional integer, the random seed. It seeds the hyperparameter
+            sampling as well as the private random generator that drives
+            the mutation, crossover and selection decisions. Defaults to
+            None.
+    """
+
+    def __init__(
+        self,
+        mutation_factor,
+        crossover_factor,
+        seed=None,
+    ):
+        self.mutation_factor = mutation_factor
+        self.crossover_factor = crossover_factor
+        self.seed = seed
+        # A private generator makes the evolutionary decisions reproducible
+        # for a given seed and leaves the global `random` state untouched.
+        self._rng = random.Random(seed)
+
+    def _initialize_population(self, life: hp_module.HyperParameters):
+        """Create one randomly sampled member of the initial population.
+
+        Args:
+            life: A `HyperParameters` instance describing the search space.
+
+        Returns:
+            A new `HyperParameters` instance with resampled values.
+        """
+        return self._mutate(life.copy(), mutate_force=True)
+
+    def _mutate(self, chromosome: hp_module.HyperParameters, mutate_force=False):
+        """Mutate a chromosome by resampling every hyperparameter in its space.
+
+        Args:
+            chromosome: A `HyperParameters` instance.
+            mutate_force: Boolean, whether to mutate regardless of
+                `mutation_factor`.
+
+        Returns:
+            A mutated copy of `chromosome`, or `chromosome` itself when no
+            mutation takes place.
+        """
+        if not mutate_force and self._rng.random() >= self.mutation_factor:
+            return chromosome
+
+        mutated_chromosome = chromosome.copy()
+        mutated_chromosome.values = {
+            hp.name: hp.random_sample(self.seed) for hp in chromosome.space
+        }
+        if self.seed is not None:
+            # Advance the seed, otherwise every mutation would sample the
+            # very same values.
+            self.seed += 1
+        return mutated_chromosome
+
+    def _crossover(
+        self,
+        parent_1: hp_module.HyperParameters,
+        parent_2: hp_module.HyperParameters,
+    ):
+        """Perform a single-point crossover of two chromosomes.
+
+        Args:
+            parent_1: A `HyperParameters` instance.
+            parent_2: A `HyperParameters` instance.
+
+        Returns:
+            A tuple of two `HyperParameters` instances. These are the
+            unchanged parents when no crossover takes place.
+        """
+        gene_count = len(parent_1.space)
+        # An empty search space has no crossover point to pick.
+        if gene_count == 0 or self._rng.random() >= self.crossover_factor:
+            return parent_1, parent_2
+
+        # Select a random crossover point.
+        crossover_point = self._rng.randint(0, gene_count - 1)
+        child_1 = parent_1.copy()
+        child_2 = parent_2.copy()
+
+        # Swap the hyperparameters after the crossover point. A value that
+        # one of the parents does not hold cannot be exchanged.
+        for hp in parent_1.space[crossover_point:]:
+            if hp.name in parent_1.values and hp.name in parent_2.values:
+                child_1.values[hp.name] = parent_2.values[hp.name]
+                child_2.values[hp.name] = parent_1.values[hp.name]
+        return child_1, child_2
+
+    def _roulette_wheel_selection(self, scores, population, maximize=True):
+        """Perform roulette wheel selection for generating a couple.
+
+        Args:
+            scores: A list or numpy array of scores, one per chromosome. It
+                is not modified.
+            population: A list of `HyperParameters` instances.
+            maximize: Boolean, whether a higher score is better. Defaults to
+                True.
+
+        Returns:
+            A tuple of two `HyperParameters` instances, drawn with
+            replacement.
+        """
+        # Work on a float copy: the caller's scores stay untouched and
+        # integer scores can be normalized as well.
+        weights = np.array(scores, dtype=float)
+        if not maximize:
+            weights = -weights
+
+        # Shift the scores so that the worst chromosome gets the weight zero.
+        weights = weights - np.min(weights)
+        total = np.sum(weights)
+        # Equal scores leave nothing to weight by; draw uniformly then.
+        weights = (weights / total).tolist() if total > 0 else None
+
+        # Generate two roulette wheel indices.
+        parent_index_1, parent_index_2 = self._rng.choices(
+            range(len(population)), weights=weights, k=2
+        )
+        return population[parent_index_1], population[parent_index_2]
+
+    def _tournament_selection(self, scores, population, maximize=True):
+        """Perform tournament selection for generating a couple.
+
+        Args:
+            scores: A list or numpy array of scores, one per chromosome.
+            population: A list of `HyperParameters` instances.
+            maximize: Boolean, whether a higher score is better. Defaults to
+                True.
+
+        Returns:
+            A tuple of two `HyperParameters` instances, the winner of the
+            tournament first.
+        """
+        # Generate two tournament indices.
+        parent_index_1, parent_index_2 = self._rng.choices(
+            range(len(population)), k=2
+        )
+        if maximize:
+            first_wins = scores[parent_index_1] > scores[parent_index_2]
+        else:
+            first_wins = scores[parent_index_1] < scores[parent_index_2]
+        if first_wins:
+            return population[parent_index_1], population[parent_index_2]
+        return population[parent_index_2], population[parent_index_1]
+
+
+class GeneticOptimizationOracle(oracle_module.Oracle):
+    """Genetic algorithm oracle.
+
+    This oracle uses a genetic algorithm to find the optimal hyperparameters.
+    It is a black-box algorithm, which only relies on the scores of the
+    finished trials. It works by keeping a population of hyperparameter
+    sets. Parents are selected from the population, crossed over and mutated
+    to produce the offspring for the next generation. A more detailed
+    description of the algorithm can be found
+    [here](https://link.springer.com/article/10.1007/BF02823145) and
+    [here](https://github.com/clever-algorithms/CleverAlgorithms).
+
+    `max_trials` is not an argument, it is calculated as
+    `population_size + generation_size * 2 * offspring_size`. The factor of
+    two stems from the parent selection: both parents of a couple are
+    evaluated in a trial of their own.
+
+    Args:
+        objective: A string, `keras_tuner.Objective` instance, or a list of
+            `keras_tuner.Objective`s and strings. If a string, the direction of
+            the optimization (min or max) will be inferred. If a list of
+            `keras_tuner.Objective`, we will minimize the sum of all the
+            objectives to minimize subtracting the sum of all the objectives to
+            maximize. The `objective` argument is optional when
+            `Tuner.run_trial()` or `HyperModel.fit()` returns a single float as
+            the objective to minimize.
+        generation_size: Integer, the number of generations to evolve.
+            Defaults to 10.
+        population_size: Integer, the number of models in the population at
+            each generation. Defaults to 10.
+        offspring_size: Integer, the number of offspring to produce at each
+            generation. By default, the offspring size is equal to the
+            population size. Defaults to None.
+        mutation_factor: Float, the probability of mutating an offspring.
+            Defaults to 0.9.
+        crossover_factor: Float, the probability of crossing over a couple
+            of parents. Defaults to 0.1.
+        threshold: Float, the score at which the search stops. The search
+            stops as soon as a finished trial scores at least as good as the
+            threshold with respect to the direction of the objective.
+            Defaults to None.
+        selection_type: String, the type of selection to use for generating
+            the offspring. Can be either "roulette_wheel" or "tournament".
+            Defaults to "tournament".
+        seed: Optional integer, the random seed.
+        hyperparameters: Optional `HyperParameters` instance. Can be used to
+            override (or register in advance) hyperparameters in the search
+            space.
+        tune_new_entries: Boolean, whether hyperparameter entries that are
+            requested by the hypermodel but that were not specified in
+            `hyperparameters` should be added to the search space, or not. If
+            not, then the default value for these parameters will be used.
+            Defaults to True.
+        allow_new_entries: Boolean, whether the hypermodel is allowed
+            to request hyperparameters that were not specified in
+            `hyperparameters`. If not, then an error will be raised. Defaults
+            to True.
+    """
+
+    def __init__(
+        self,
+        objective=None,
+        generation_size=10,
+        population_size=10,
+        offspring_size=None,
+        mutation_factor=0.9,
+        crossover_factor=0.1,
+        threshold=None,
+        selection_type="tournament",
+        seed=None,
+        hyperparameters=None,
+        tune_new_entries=True,
+        allow_new_entries=True,
+    ):
+        self.generation_size = generation_size
+        self.population_size = population_size
+        self.offspring_size = offspring_size or population_size
+        self.max_trials = self._make_max_trials
+
+        super(GeneticOptimizationOracle, self).__init__(
+            objective=objective,
+            max_trials=self.max_trials,
+            hyperparameters=hyperparameters,
+            tune_new_entries=tune_new_entries,
+            allow_new_entries=allow_new_entries,
+            seed=seed,
+        )
+
+        self.mutation_factor = mutation_factor
+        self.crossover_factor = crossover_factor
+        self.threshold = threshold
+        self.selection_type = selection_type
+        self.seed = seed or random.randint(1, int(1e4))
+        self._seed_state = self.seed
+        self._random_state = np.random.RandomState(self.seed)
+
+        # A sum of exactly 1.0, as given by the defaults, is accepted.
+        if self.mutation_factor + self.crossover_factor > 1.0:
+            raise ValueError(
+                "The sum of the 'mutation_factors' and 'crossover_factors' "
+                "must be less than 1.0."
+            )
+        if self.selection_type not in ["roulette_wheel", "tournament"]:
+            raise ValueError(
+                "The 'selection_type' must be either 'roulette_wheel' or "
+                "'tournament'."
+            )
+        self._tried_so_far = set()
+        self._max_collisions = 20
+        self.gep = self._make_gep()
+        self.ranges = self._make_ranges
+        self.population = {"hyperparameters": [], "scores": []}
+        self.new_population = {"hyperparameters": [], "scores": []}
+        self.values = {hp.name: hp.default for hp in self.get_space().space}
+        self.parent_1, self.parent_2 = None, None
+
+    @property
+    def _maximize(self):
+        """Whether a higher score is the better one for the objective."""
+        return getattr(self.objective, "direction", "min") == "max"
+
+    @property
+    def _make_max_trials(self):
+        """Calculate the maximum number of trials."""
+        return self.population_size + self.generation_size * 2 * self.offspring_size
+
+    @property
+    def _make_ranges(self):
+        """Make the ranges for genetic optimization with respect to max trial.
+
+        The ranges hold trial numbers. They are matched against the number
+        of the most recently started trial in `populate_space`:
+
+        1. 'population': a member of the initial population is sampled.
+        2. 'generation': the offspring replace the population.
+        3. 'first_parent': a couple is selected and crossed over, and the
+            child of the first parent is scheduled.
+        4. 'second_parent': the child of the second parent is scheduled.
+
+        'first_parent' and 'second_parent' alternate, because only one set
+        of hyperparameters can be returned per trial.
+
+        Returns:
+            A dict of four lists: 'population', 'generation',
+            'first_parent', 'second_parent'.
+        """
+        population_range = list(range(self.population_size))
+        # NOTE(review): a generation spans `offspring_size * 2` trials, yet
+        # the step is one less than that, so later generations start early.
+        # The step is kept as it is, since callers may rely on these values.
+        generation_range = list(
+            range(
+                self.population_size + self.offspring_size * 2,
+                self.max_trials,
+                self.offspring_size * 2 - 1,
+            )
+        )
+        offspring_range = list(range(self.population_size, self.max_trials))
+        first_parent_range = offspring_range[::2]
+        second_parent_range = offspring_range[1::2]
+        return {
+            "population": population_range,
+            "generation": generation_range,
+            "first_parent": first_parent_range,
+            "second_parent": second_parent_range,
+        }
+
+    def _make_gep(self):
+        """Make a genetic evolutionary process.
+
+        Returns:
+            A `GeneticEvolutionaryProcess` instance.
+        """
+        return GeneticEvolutionaryProcess(
+            mutation_factor=self.mutation_factor,
+            crossover_factor=self.crossover_factor,
+            seed=self.seed,
+        )
+
+    def _check_score(self, score):
+        """Check if a score reaches the threshold.
+
+        Args:
+            score: The score to check, or None if there is no score yet.
+
+        Returns:
+            A `dict` with the status and the current values if the score is
+            at least as good as the threshold, otherwise None.
+        """
+        if self.threshold is None or score is None:
+            return None
+        if self._maximize:
+            reached = score >= self.threshold
+        else:
+            reached = score <= self.threshold
+        if not reached:
+            return None
+        return {
+            "status": trial_module.TrialStatus.COMPLETED,
+            "values": self.values,
+        }
+
+    @property
+    def _get_current_score(self):
+        """Get the score of the most recently started trial.
+
+        Returns:
+            The score of that trial.
+        """
+        return self.trials[self.start_order[-1]].score
+
+    def _add_offspring(self, parent, score):
+        """Mutate a parent and schedule the resulting child for the trial.
+
+        The child that is stored in the new population is the very same one
+        whose values are handed out for the trial.
+
+        Args:
+            parent: A `HyperParameters` instance.
+            score: The score to store alongside the child.
+        """
+        child = self.gep._mutate(parent)
+        self.values = child.values
+        self.new_population["hyperparameters"].append(child)
+        self.new_population["scores"].append(score)
+
+    def populate_space(self, trial_id):
+        """Populate the space for the genetic algorithm.
+
+        The first trial runs with the default values. While the initial
+        population is built, every trial gets randomly sampled
+        hyperparameters. Afterwards, couples are selected from the
+        population via 'tournament' or 'roulette_wheel' selection, crossed
+        over and mutated, and the children are collected as the new
+        population, which replaces the population at the end of a
+        generation. The search is stopped once the most recently started
+        trial reaches the threshold.
+
+        Args:
+            trial_id: The current trial ID.
+
+        Returns:
+            A dictionary with the status and the hyperparameter values for
+            the current trial.
+        """
+        if len(self.start_order) > 0:
+            last_trial = int(self.start_order[-1])
+            # NOTE(review): this is the score of the previous trial, whereas
+            # the chromosome stored with it below belongs to the current
+            # trial; confirm that this pairing is intended.
+            score = self._get_current_score
+
+            if self._check_score(score) is not None:
+                return {"status": trial_module.TrialStatus.STOPPED, "values": None}
+
+            # Start with population
+            if last_trial in self.ranges["population"]:
+                member = self.gep._initialize_population(self.hyperparameters)
+                self.values = member.values
+                self.population["hyperparameters"].append(member)
+                self.population["scores"].append(score)
+
+            # Start with generation and offspring
+            if last_trial in self.ranges["first_parent"]:
+                if self.selection_type == "tournament":
+                    select_couple = self.gep._tournament_selection
+                else:
+                    select_couple = self.gep._roulette_wheel_selection
+                self.parent_1, self.parent_2 = select_couple(
+                    scores=self.population["scores"],
+                    population=self.population["hyperparameters"],
+                    maximize=self._maximize,
+                )
+                self.parent_1, self.parent_2 = self.gep._crossover(
+                    parent_1=self.parent_1, parent_2=self.parent_2
+                )
+                self._add_offspring(self.parent_1, score)
+            # Second parent for the offspring generation to be evaluated
+            elif last_trial in self.ranges["second_parent"]:
+                self._add_offspring(self.parent_2, score)
+
+            if last_trial in self.ranges["generation"]:
+                self.population = deepcopy(self.new_population)
+                self.new_population = {"hyperparameters": [], "scores": []}
+
+        if self.values is None:
+            return {"status": trial_module.TrialStatus.STOPPED, "values": None}
+        return {"status": trial_module.TrialStatus.RUNNING, "values": self.values}
+
+    @staticmethod
+    def _population_to_state(population):
+        """Convert a population into a JSON-serializable `dict`."""
+        return {
+            "hyperparameters": [
+                hps.get_config() for hps in population["hyperparameters"]
+            ],
+            "scores": [
+                None if score is None else float(score)
+                for score in population["scores"]
+            ],
+        }
+
+    @staticmethod
+    def _population_from_state(state):
+        """Restore a population from the output of `_population_to_state`."""
+        return {
+            "hyperparameters": [
+                hp_module.HyperParameters.from_config(config)
+                for config in state["hyperparameters"]
+            ],
+            "scores": list(state["scores"]),
+        }
+
+    def get_state(self):
+        """Get the state of the genetic algorithm."""
+        state = super(GeneticOptimizationOracle, self).get_state()
+        state.update(
+            {
+                "mutation_factor": self.mutation_factor,
+                "crossover_factor": self.crossover_factor,
+                "population": self._population_to_state(self.population),
+                "new_population": self._population_to_state(
+                    self.new_population
+                ),
+            }
+        )
+        return state
+
+    def set_state(self, state):
+        """Set the state of the genetic algorithm."""
+        super(GeneticOptimizationOracle, self).set_state(state)
+        self.mutation_factor = state["mutation_factor"]
+        self.crossover_factor = state["crossover_factor"]
+        # States that were saved without the populations keep the current
+        # populations.
+        if "population" in state:
+            self.population = self._population_from_state(state["population"])
+        if "new_population" in state:
+            self.new_population = self._population_from_state(
+                state["new_population"]
+            )
+        self.gep = self._make_gep()
+
+
+# Generate Genetic Algorithm Tuner
+class GeneticOptimization(tuner_module.Tuner):
+    """Genetic Optimization tuning with Genetic Evolutionary Process.
+
+    Args:
+        hypermodel: Instance of `HyperModel` class (or callable that takes
+            hyperparameters and returns a `Model` instance). It is optional
+            when `Tuner.run_trial()` is overridden and does not use
+            `self.hypermodel`.
+        objective: A string, `keras_tuner.Objective` instance, or a list of
+            `keras_tuner.Objective`s and strings. If a string, the direction of
+            the optimization (min or max) will be inferred. If a list of
+            `keras_tuner.Objective`, we will minimize the sum of all the
+            objectives to minimize subtracting the sum of all the objectives to
+            maximize. The `objective` argument is optional when
+            `Tuner.run_trial()` or `HyperModel.fit()` returns a single float as
+            the objective to minimize.
+        generation_size: Integer, the number of generations to evolve.
+            Defaults to 10.
+        population_size: Integer, the number of models in the population at
+            each generation. Defaults to 10.
+        offspring_size: Integer, the number of offspring to produce at each
+            generation. By default, the offspring size is equal to the
+            population size. Defaults to None.
+        mutation_factor: Float, the probability of mutating an offspring.
+            Defaults to 0.9.
+        crossover_factor: Float, the probability of crossing over a couple
+            of parents. Defaults to 0.1.
+        threshold: Float, the score at which the search stops. The search
+            stops as soon as a finished trial scores at least as good as the
+            threshold with respect to the direction of the objective.
+            Defaults to None.
+        selection_type: String, the type of selection to use for generating
+            the offspring. Can be either "roulette_wheel" or "tournament".
+            Defaults to "tournament".
+        seed: Optional integer, the random seed.
+        hyperparameters: Optional `HyperParameters` instance. Can be used to
+            override (or register in advance) hyperparameters in the search
+            space.
+        tune_new_entries: Boolean, whether hyperparameter entries that are
+            requested by the hypermodel but that were not specified in
+            `hyperparameters` should be added to the search space, or not. If
+            not, then the default value for these parameters will be used.
+            Defaults to True.
+        allow_new_entries: Boolean, whether the hypermodel is allowed to
+            request hyperparameter entries not listed in `hyperparameters`.
+            Defaults to True.
+        **kwargs: Keyword arguments relevant to all `Tuner` subclasses. Please
+            see the docstring for `Tuner`.
+    """
+
+    def __init__(
+        self,
+        hypermodel=None,
+        objective=None,
+        generation_size=10,
+        population_size=10,
+        offspring_size=None,
+        mutation_factor=0.9,
+        crossover_factor=0.1,
+        threshold=None,
+        selection_type="tournament",
+        seed=None,
+        hyperparameters=None,
+        tune_new_entries=True,
+        allow_new_entries=True,
+        **kwargs
+    ):
+        oracle = GeneticOptimizationOracle(
+            objective=objective,
+            generation_size=generation_size,
+            population_size=population_size,
+            offspring_size=offspring_size,
+            mutation_factor=mutation_factor,
+            crossover_factor=crossover_factor,
+            threshold=threshold,
+            selection_type=selection_type,
+            seed=seed,
+            hyperparameters=hyperparameters,
+            tune_new_entries=tune_new_entries,
+            allow_new_entries=allow_new_entries,
+        )
+        super(GeneticOptimization, self).__init__(
+            oracle=oracle, hypermodel=hypermodel, **kwargs
+        )
diff --git a/keras_tuner/tuners/genetic_test.py b/keras_tuner/tuners/genetic_test.py
new file mode 100644
index 000000000..22fa9ee6d
--- /dev/null
+++ b/keras_tuner/tuners/genetic_test.py
@@ -0,0 +1,646 @@
+# Copyright 2019 The KerasTuner Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+ +import numpy as np +import pytest +import tensorflow as tf + +import keras_tuner +from keras_tuner.engine import hyperparameters as hp_module +from keras_tuner.tuners import genetic as ge_module + + +def build_model(hp): + model = tf.keras.Sequential() + model.add(tf.keras.layers.Flatten(input_shape=(2, 2))) + for i in range(3): + model.add( + tf.keras.layers.Dense( + units=hp.Int("units_" + str(i), 2, 4, 2), activation="relu" + ) + ) + model.add(tf.keras.layers.Dense(2, activation="softmax")) + model.compile( + optimizer=tf.keras.optimizers.Adam( + hp.Choice("learning_rate", [1e-2, 1e-3, 1e-4]) + ), + loss="sparse_categorical_crossentropy", + metrics=["accuracy"], + ) + return model + + +def test_mutation(): + """Test mutation of a chromosome.""" + hps = hp_module.HyperParameters() + hps.Choice("a", [1, 2, 3]) + hps.Float("b", 0, 1) + hps.Int("c", 0, 10) + hps.Fixed("d", 1) + + gep = ge_module.GeneticEvolutionaryProcess( + mutation_factor=1.1, crossover_factor=0.5 + ) + mutated = gep._mutate(hps) + assert mutated.values != hps.values + assert mutated.get("a") in [1, 2, 3] + assert 0 <= mutated.get("b") <= 1 + assert 0 <= mutated.get("c") <= 10 + assert mutated.get("d") == 1 + + +def test_no_mutation(): + """Test that no mutation occurs when mutation factor is 0.""" + hps = hp_module.HyperParameters() + hps.Choice("a", [1, 2, 3]) + hps.Float("b", 0, 1) + hps.Int("c", 0, 10) + hps.Fixed("d", 1) + + gep = ge_module.GeneticEvolutionaryProcess( + mutation_factor=-1, crossover_factor=0.5 + ) + mutated = gep._mutate(hps) + assert mutated.values == hps.values + + +def test_crossover(): + """Test crossover of two chromosomes.""" + hp1 = hp_module.HyperParameters() + hp1.Choice("a", [1, 2, 3]) + hp1.Float("b", 0, 1, step=0.1) + hp1.Int("c", 0, 10, step=2) + hp1.Fixed("d", 1) + + hp2 = hp_module.HyperParameters() + hp2.Choice("a", [4, 5, 1]) + hp2.Float("b", -1, 0, step=0.1) + hp2.Int("c", 10, 20, step=2) + hp2.Fixed("d", 2) + + gep = ge_module.GeneticEvolutionaryProcess( 
+ mutation_factor=0, crossover_factor=1.1 + ) + parent_1, parent_2 = gep._crossover(hp1, hp2) + assert parent_1.values != hp1.values + assert parent_2.values != hp2.values + assert 0 <= parent_1.get("b") <= 1 + assert -1 <= parent_2.get("b") <= 0 + + +def test_no_crossover(): + """Test that no crossover occurs when crossover factor is 0.""" + hp1 = hp_module.HyperParameters() + hp1.Choice("a", [1, 2, 3]) + hp1.Float("b", 0, 1) + hp1.Int("c", 0, 10) + hp1.Fixed("d", 1) + + hp2 = hp_module.HyperParameters() + hp2.Choice("a", [4, 5, 1]) + hp2.Float("b", -1, 0) + hp2.Int("c", 10, 20) + hp2.Fixed("d", 2) + + gep = ge_module.GeneticEvolutionaryProcess( + mutation_factor=0, crossover_factor=-1 + ) + parent_1, parent_2 = gep._crossover(hp1, hp2) + assert parent_1.values == hp1.values + assert parent_2.values == hp2.values + + +def test_make_ranges_without_offspring_size(): + """Test that the ranges are created correctly.""" + + goo = ge_module.GeneticOptimizationOracle( + objective=keras_tuner.Objective("val_accuracy", "max"), + generation_size=2, + population_size=10, + offspring_size=None, + ) + + ranges = goo._make_ranges + assert goo.max_trials == 10 + 2 * 10 * 2 + assert ranges["population"] == list(range(10)) + assert ranges["generation"] == list(range(30, 50, 19)) + assert ranges["first_parent"] == list(range(10, 50))[::2] + assert ranges["second_parent"] == list(range(10, 50))[1::2] + + +def test_make_with_offspring_size(): + """Test that the ranges are created correctly.""" + + goo = ge_module.GeneticOptimizationOracle( + objective=keras_tuner.Objective("val_accuracy", "max"), + generation_size=2, + population_size=10, + offspring_size=5, + ) + + ranges = goo._make_ranges + assert goo.max_trials == 10 + 2 * 5 * 2 + assert ranges["population"] == list(range(10)) + assert ranges["generation"] == list(range(20, 30, 9)) + assert ranges["first_parent"] == list(range(10, 30))[::2] + assert ranges["second_parent"] == list(range(10, 30))[1::2] + + +def 
test_raises_factor(): + """Test that the crossover and mutation factors are in the correct limit.""" + + with pytest.raises(ValueError) as excinfo: + ge_module.GeneticOptimizationOracle( + objective=keras_tuner.Objective("val_accuracy", "max"), + generation_size=2, + population_size=21, + offspring_size=5, + mutation_factor=1.2, + crossover_factor=0.5, + ) + + assert ( + "The sum of the 'mutation_factors' and " + "'crossover_factors' must be less than 1.0." in str(excinfo.value) + ) + + +def test_raises_selection(): + """Test that the selection method is correct.""" + + with pytest.raises(ValueError) as excinfo: + ge_module.GeneticOptimizationOracle( + objective=keras_tuner.Objective("val_accuracy", "max"), + generation_size=2, + population_size=21, + offspring_size=5, + selection_type="random", + ) + assert ( + "The 'selection_type' must be either 'roulette_wheel' or " + "'tournament'." in str(excinfo.value) + ) + + +def test_roulette_wheel_selection(): + """Test that the roulette wheel selection works correctly.""" + + scores = [1.0, 2.0, 3.0, 4.0, 5.0] + hps = hp_module.HyperParameters() + hps.Choice("a", [1, 2, 3]) + hps.Float("b", 0, 1) + hps.Int("c", 0, 10) + hps.Fixed("d", 1) + + gep = ge_module.GeneticEvolutionaryProcess(mutation_factor=1, crossover_factor=0) + population = [gep._mutate(hps) for _ in range(5)] + + parent_1, parent_2 = gep._roulette_wheel_selection( + scores=scores, population=population + ) + assert parent_1.values != parent_2.values + assert len(parent_1.values) == len(parent_2.values) + assert len(parent_1.values) == len(hps.values) + assert len(parent_2.values) == len(hps.values) + assert parent_1.values in [hp.values for hp in population] + assert parent_2.values in [hp.values for hp in population] + + +def test_tournament_selection(): + """Test that the tournament selection works correctly.""" + + scores = [1.0, 2.0, 3.0, 4.0, 5.0] + hps = hp_module.HyperParameters() + hps.Choice("a", [1, 2, 3]) + hps.Float("b", 0, 1) + hps.Int("c", 0, 
10) + hps.Fixed("d", 1) + + gep = ge_module.GeneticEvolutionaryProcess(mutation_factor=1, crossover_factor=0) + population = [gep._mutate(hps) for _ in range(5)] + + parent_1, parent_2 = gep._tournament_selection( + scores=scores, population=population + ) + assert parent_1.values != parent_2.values + assert len(parent_1.values) == len(parent_2.values) + assert len(parent_1.values) == len(hps.values) + assert len(parent_2.values) == len(hps.values) + assert parent_1.values in [hp.values for hp in population] + assert parent_2.values in [hp.values for hp in population] + + +def test_populate_space_init(tmp_path): + """Test that the population is initialized correctly.""" + + hps = hp_module.HyperParameters() + hps.Choice("a", [1, 2, 3]) + hps.Float("b", 0, 1) + hps.Int("c", 0, 10) + hps.Fixed("d", 1) + + oracle = ge_module.GeneticOptimizationOracle( + objective=keras_tuner.Objective("val_accuracy", "max"), + generation_size=5, + population_size=10, + offspring_size=5, + hyperparameters=hps, + ) + oracle._set_project_dir(tmp_path, "untitled") + oracle._populate_space("00") + assert oracle.population_size == 10 + assert oracle.offspring_size == 5 + + +def test_populate_inits(tmp_path): + """Test that the initiale batch is true random.""" + + hps = hp_module.HyperParameters() + hps.Choice("a", [1, 2, 3]) + hps.Float("b", 0, 1) + hps.Int("c", 0, 10) + hps.Fixed("d", 1) + + oracle = ge_module.GeneticOptimizationOracle( + objective=keras_tuner.Objective("val_accuracy", "max"), + generation_size=5, + population_size=10, + offspring_size=5, + hyperparameters=hps, + ) + oracle._set_project_dir(tmp_path, "untitled") + for _ in range(10): + trial = oracle.create_trial("trial_id") + oracle.update_trial(trial.trial_id, {"val_accuracy": 0.2}) + oracle.end_trial(trial.trial_id, "COMPLETED") + + assert ( + oracle.population["hyperparameters"][1] + != oracle.population["hyperparameters"][2] + ) + assert ( + oracle.population["hyperparameters"][1] + != 
oracle.population["hyperparameters"][3]
+    )
+    assert oracle.population["scores"][1] == oracle.population["scores"][4]
+
+
+def test_score_early_ridged_min(tmp_path):
+    """Test early stopping with a min objective."""
+
+    hps = hp_module.HyperParameters()
+    hps.Choice("a", [1, 2, 3])
+    hps.Float("b", 0, 1)
+    hps.Int("c", 0, 10)
+    hps.Fixed("d", 1)
+
+    oracle = ge_module.GeneticOptimizationOracle(
+        objective=keras_tuner.Objective("val_accuracy", "min"),
+        generation_size=5,
+        population_size=10,
+        offspring_size=5,
+        hyperparameters=hps,
+        threshold=0.1,
+    )
+    oracle._set_project_dir(tmp_path, "untitled")
+    for _ in range(2):
+        trial = oracle.create_trial("trial_id")
+        oracle.update_trial(trial.trial_id, {"val_accuracy": 0.2})
+        oracle.end_trial(trial.trial_id, "COMPLETED")
+
+    _trial = oracle._check_score(0.01)
+    assert _trial["status"] == "COMPLETED"
+    assert _trial["values"] != hps.values
+
+
+def test_score_stopped(tmp_path):
+    """Test that populate_space() returns STOPPED after the values are cleared."""
+
+    hps = hp_module.HyperParameters()
+    hps.Choice("a", [1, 2, 3])
+    hps.Float("b", 0, 1)
+    hps.Int("c", 0, 10)
+    hps.Fixed("d", 1)
+
+    oracle = ge_module.GeneticOptimizationOracle(
+        objective=keras_tuner.Objective("val_accuracy", "min"),
+        generation_size=5,
+        population_size=10,
+        offspring_size=5,
+        hyperparameters=hps,
+        threshold=0.1,
+    )
+    oracle._set_project_dir(tmp_path, "untitled")
+    for _ in range(1):
+        trial = oracle.create_trial("trial_id")
+        oracle.update_trial(trial.trial_id, {"val_accuracy": 0.2})
+        oracle.end_trial(trial.trial_id, "COMPLETED")
+
+    oracle.values = None
+    oracle.start_order = []
+    _trial = oracle.populate_space("00")
+    assert _trial["status"] == "STOPPED"
+    assert _trial["values"] is None
+
+
+def test_get_set_state(tmp_path):
+    """Test that the state is correctly set and retrieved."""
+
+    hps = hp_module.HyperParameters()
+    hps.Choice("a", [1, 2, 3])
+    hps.Float("b", 0, 1)
+    hps.Int("c", 0, 10)
+    hps.Fixed("d", 1)
+
+    oracle = ge_module.GeneticOptimizationOracle(
+        objective=keras_tuner.Objective("val_accuracy", "min"),
+        generation_size=5,
+        population_size=10,
+        offspring_size=5,
+        hyperparameters=hps,
+        threshold=0.1,
+    )
+    oracle._set_project_dir(tmp_path, "untitled")
+    for _ in range(4):
+        trial = oracle.create_trial("trial_id")
+        oracle.update_trial(trial.trial_id, {"val_accuracy": 0.2})
+        oracle.end_trial(trial.trial_id, "COMPLETED")
+
+    state = oracle.get_state()
+    oracle.set_state(state)
+    assert oracle.population_size == 10
+    assert oracle.offspring_size == 5
+    assert oracle.generation_size == 5
+    assert oracle.threshold == 0.1
+    assert oracle.population["scores"][0] == 0.2
+    assert oracle.population["scores"][1] == 0.2
+    assert oracle.population["hyperparameters"][0] != hps.values
+    assert oracle.population["hyperparameters"][1] != hps.values
+    assert (
+        oracle.population["hyperparameters"][0]
+        != oracle.population["hyperparameters"][1]
+    )
+    assert (
+        oracle.population["hyperparameters"][0]
+        != oracle.population["hyperparameters"][2]
+    )
+    assert (
+        oracle.population["hyperparameters"][1]
+        != oracle.population["hyperparameters"][2]
+    )
+
+
+def test_tournament(tmp_path):
+    """Test that the tournament selection works correctly."""
+
+    hps = hp_module.HyperParameters()
+    hps.Choice("a", [1, 2, 3])
+    hps.Float("b", 0, 1, step=0.1)
+    hps.Int("c", 0, 10, step=9)
+    hps.Fixed("d", 1)
+
+    oracle = ge_module.GeneticOptimizationOracle(
+        objective=keras_tuner.Objective("val_accuracy", "max"),
+        generation_size=5,
+        population_size=10,
+        offspring_size=5,
+        hyperparameters=hps,
+        selection_type="tournament",
+    )
+    oracle._set_project_dir(tmp_path, "untitled")
+    for _ in range(60):
+        trial = oracle.create_trial("trial_id")
+        oracle.update_trial(trial.trial_id, {"val_accuracy": 0.2})
+        oracle.end_trial(trial.trial_id, "COMPLETED")
+
+    assert (
+        oracle.population["hyperparameters"][1]
+        != oracle.population["hyperparameters"][2]
+    )
+    assert (
+        oracle.population["hyperparameters"][1]
+        != oracle.population["hyperparameters"][3]
+    )
+    assert oracle.population["scores"][1] == oracle.population["scores"][4]
+
+
+def test_roulette(tmp_path):
+    """Test that the roulette selection works correctly."""
+
+    hps = hp_module.HyperParameters()
+    hps.Choice("a", [1, 2, 3])
+    hps.Float("b", 0, 1, step=0.1)
+    hps.Int("c", 0, 10, step=9)
+    hps.Fixed("d", 1)
+
+    oracle = ge_module.GeneticOptimizationOracle(
+        objective=keras_tuner.Objective("val_accuracy", "max"),
+        generation_size=5,
+        population_size=10,
+        offspring_size=5,
+        hyperparameters=hps,
+        selection_type="roulette_wheel",
+    )
+    oracle._set_project_dir(tmp_path, "untitled")
+    for _ in range(60):
+        trial = oracle.create_trial("trial_id")
+        oracle.update_trial(trial.trial_id, {"val_accuracy": 0.2})
+        oracle.end_trial(trial.trial_id, "COMPLETED")
+
+    assert (
+        oracle.population["hyperparameters"][1]
+        != oracle.population["hyperparameters"][2]
+    )
+    assert (
+        oracle.population["hyperparameters"][1]
+        != oracle.population["hyperparameters"][3]
+    )
+    assert oracle.population["scores"][1] == oracle.population["scores"][4]
+
+
+def test_genetic_minimize_tournament(tmp_path):
+    """Test that the tournament selection works correctly for minimization."""
+
+    class MyTuner(keras_tuner.GeneticOptimization):
+        def run_trial(self, trial, *args, **kwargs):
+            # Get the hp from trial.
+            hp = trial.hyperparameters
+            # Define "x" as a hyperparameter.
+            x = hp.Float("x", min_value=-1.0, max_value=1.0, step=0.001)
+            # Return the objective value to minimize.
+            return x * x + 1
+
+    tuner = MyTuner(
+        # No hypermodel specified; the objective is passed explicitly below.
+        overwrite=True,
+        directory=tmp_path,
+        project_name="tune_anything",
+        objective=keras_tuner.Objective("score", "min"),
+        population_size=10,
+        offspring_size=5,
+        generation_size=5,
+        mutation_factor=0.9,
+        crossover_factor=0.1,
+        selection_type="tournament",
+    )
+
+    # No need to pass anything to search()
+    # unless you use them in run_trial().
+    tuner.search()
+    assert tuner.oracle.selection_type == "tournament"
+    assert np.isclose(tuner.get_best_hyperparameters()[0].get("x"), 0.0, atol=1e-1)
+
+
+def test_genetic_minimize_roulette(tmp_path):
+    """Test that the roulette selection works correctly for minimization."""
+
+    class MyTuner(keras_tuner.GeneticOptimization):
+        def run_trial(self, trial, *args, **kwargs):
+            # Get the hp from trial.
+            hp = trial.hyperparameters
+            # Define "x" as a hyperparameter.
+            x = hp.Float("x", min_value=-1.0, max_value=1.0, step=0.001)
+            # Return the objective value to minimize.
+            return x * x + 1
+
+    tuner = MyTuner(
+        # No hypermodel specified; the objective is passed explicitly below.
+        overwrite=True,
+        directory=tmp_path,
+        project_name="tune_anything",
+        objective=keras_tuner.Objective("score", "min"),
+        population_size=10,
+        offspring_size=5,
+        generation_size=5,
+        mutation_factor=0.9,
+        crossover_factor=0.1,
+        selection_type="roulette_wheel",
+    )
+
+    # No need to pass anything to search()
+    # unless you use them in run_trial().
+    tuner.search()
+    assert tuner.oracle.selection_type == "roulette_wheel"
+    assert np.isclose(tuner.get_best_hyperparameters()[0].get("x"), 0.0, atol=1e-1)
+
+
+def test_genetic_maximize_tournament(tmp_path):
+    """Test that the tournament selection works correctly for maximization."""
+
+    class MyTuner(keras_tuner.GeneticOptimization):
+        def run_trial(self, trial, *args, **kwargs):
+            # Get the hp from trial.
+            hp = trial.hyperparameters
+            # Define "x" as a hyperparameter.
+            x = hp.Float("x", min_value=-1.0, max_value=1.0, step=0.001)
+            # Objective to maximize; x * x + 1 peaks at both x = -1 and x = 1.
+            return x * x + 1
+
+    tuner = MyTuner(
+        # No hypermodel specified; the objective is passed explicitly below.
+        overwrite=True,
+        directory=tmp_path,
+        project_name="tune_anything",
+        objective=keras_tuner.Objective("score", "max"),
+        population_size=10,
+        offspring_size=5,
+        generation_size=5,
+        mutation_factor=0.9,
+        crossover_factor=0.1,
+        selection_type="tournament",
+    )
+
+    # No need to pass anything to search()
+    # unless you use them in run_trial().
+    tuner.search()
+    assert tuner.oracle.selection_type == "tournament"
+    assert abs(tuner.get_best_hyperparameters()[0].get("x")) >= 0.9
+
+
+def test_genetic_maximize_roulette(tmp_path):
+    """Test that the roulette selection works correctly for maximization."""
+
+    class MyTuner(keras_tuner.GeneticOptimization):
+        def run_trial(self, trial, *args, **kwargs):
+            # Get the hp from trial.
+            hp = trial.hyperparameters
+            # Define "x" as a hyperparameter.
+            x = hp.Float("x", min_value=-1.0, max_value=1.0, step=0.001)
+            # Objective to maximize; x * x + 1 peaks at both x = -1 and x = 1.
+            return x * x + 1
+
+    tuner = MyTuner(
+        # No hypermodel specified; the objective is passed explicitly below.
+        overwrite=True,
+        directory=tmp_path,
+        project_name="tune_anything",
+        objective=keras_tuner.Objective("score", "max"),
+        population_size=10,
+        offspring_size=5,
+        generation_size=5,
+        mutation_factor=0.9,
+        crossover_factor=0.1,
+        selection_type="roulette_wheel",
+    )
+
+    # No need to pass anything to search()
+    # unless you use them in run_trial().
+    tuner.search()
+    assert tuner.oracle.selection_type == "roulette_wheel"
+    assert abs(tuner.get_best_hyperparameters()[0].get("x")) >= 0.9
+
+
+def test_genetic_minimize_tournament_with_hypermodel(tmp_path):
+    """Test that the roulette selection works correctly for mandelbrot."""
+
+    class MandelBrotTuner(keras_tuner.GeneticOptimization):
+        def run_trial(self, trial, *args, **kwargs):
+            # Get the hp from trial.
+            hp = trial.hyperparameters
+            # Define "x" and "y" as hyperparameters.
+            x = hp.Float("x", min_value=-2.0, max_value=2.0, step=0.001)
+            y = hp.Float("y", min_value=-2.0, max_value=2.0, step=0.001)
+            # Return the objective value to minimize.
+            return self.mandelbrot(x, y)
+
+        def mandelbrot(self, x, y):
+            c = complex(x, y)
+            z = 0.0j
+            for i in range(1, 1000):
+                z = z**2 + c
+                if (z.real * z.real + z.imag * z.imag) >= 4:
+                    return i
+            return 0
+
+    tuner = MandelBrotTuner(
+        overwrite=True,
+        directory=tmp_path,
+        project_name="tune_anything",
+        objective=keras_tuner.Objective("score", "min"),
+        population_size=20,
+        offspring_size=25,
+        generation_size=50,
+        mutation_factor=0.9,
+        crossover_factor=0.1,
+        selection_type="roulette_wheel",
+        threshold=0.1,
+    )
+    tuner.search()
+    assert tuner.oracle.selection_type == "roulette_wheel"
+    assert np.isclose(
+        tuner.get_best_hyperparameters()[0].get("x"), -0.363, atol=1e-2
+    )
+    assert np.isclose(
+        tuner.get_best_hyperparameters()[0].get("y"), -0.363, atol=1e-2
+    )
diff --git a/setup.cfg b/setup.cfg
index 2d1068bc4..ac84a97db 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -37,6 +37,12 @@ ignore = E121,E123,E126,E226,E24,E704,W503,W504
     # Function name should be lowercase
     N802
+    # Argument name should be lowercase
+    N803
+    # First argument of a classmethod should be named 'cls'
+    N804
+    # Variable in function should be lowercase
+    N806
     # lowercase ... imported as non lowercase
     # Useful to ignore for "import keras.backend as K"
     N812