From f32135d3b2b1e1967e721f969f9158b4ab1233a8 Mon Sep 17 00:00:00 2001 From: herilalaina Date: Fri, 7 Jul 2023 09:36:27 +0200 Subject: [PATCH 1/2] remove hp --- ConfigSpace/conditions.pyx | 9 ++++++ ConfigSpace/forbidden.pyx | 14 +++++++- ConfigSpace/util.py | 65 ++++++++++++++++++++++++++++++++++++++ 3 files changed, 87 insertions(+), 1 deletion(-) diff --git a/ConfigSpace/conditions.pyx b/ConfigSpace/conditions.pyx index 90e71c26..df41078a 100644 --- a/ConfigSpace/conditions.pyx +++ b/ConfigSpace/conditions.pyx @@ -74,6 +74,9 @@ cdef class ConditionComponent(object): def evaluate_vector(self, instantiated_vector): return bool(self._evaluate_vector(instantiated_vector)) + def get_referenced_hyperparameters(self) -> List[Hyperparameter]: + pass + cdef int _evaluate_vector(self, np.ndarray value): pass @@ -147,6 +150,9 @@ cdef class AbstractCondition(ConditionComponent): hp_name = self.parent.name return self._evaluate(instantiated_parent_hyperparameter[hp_name]) + def get_referenced_hyperparameters(self) -> List[Hyperparameter]: + return self.get_children() + self.get_parents() + cdef int _evaluate_vector(self, np.ndarray instantiated_vector): if self.parent_vector_id is None: raise ValueError("Parent vector id should not be None when calling evaluate vector") @@ -523,6 +529,9 @@ cdef class AbstractConjunction(ConditionComponent): raise ValueError("All Conjunctions and Conditions must have " "the same child.") + def get_referenced_hyperparameters(self) -> List[Hyperparameter]: + return sum([cond.get_referenced_hyperparameters() for cond in self.components], []) + def __eq__(self, other: Any) -> bool: """ This method implements a comparison between self and another diff --git a/ConfigSpace/forbidden.pyx b/ConfigSpace/forbidden.pyx index 9ba0dd30..062a59c3 100644 --- a/ConfigSpace/forbidden.pyx +++ b/ConfigSpace/forbidden.pyx @@ -31,7 +31,7 @@ import numpy as np import io from ConfigSpace.hyperparameters import Hyperparameter from ConfigSpace.hyperparameters.hyperparameter cimport Hyperparameter -from typing import Dict, Any, Union +from typing import Dict, Any, Union, List from ConfigSpace.forbidden cimport AbstractForbiddenComponent @@ -71,6 +71,9 @@ cdef class AbstractForbiddenComponent(object): return (self.value == other.value and self.hyperparameter.name == other.hyperparameter.name) + def get_referenced_hyperparameters(self) -> List[Hyperparameter]: + pass + def __hash__(self) -> int: """Override the default hash behavior (that returns the id or the object)""" return hash(tuple(sorted(self.__dict__.items()))) @@ -103,6 +106,9 @@ cdef class AbstractForbiddenClause(AbstractForbiddenComponent): self.hyperparameter = hyperparameter self.vector_id = -1 + def get_referenced_hyperparameters(self) -> List[Hyperparameter]: + return [self.hyperparameter] + cpdef get_descendant_literal_clauses(self): return (self, ) @@ -315,6 +321,9 @@ cdef class AbstractForbiddenConjunction(AbstractForbiddenComponent): self.n_components = len(self.components) self.dlcs = self.get_descendant_literal_clauses() + def get_referenced_hyperparameters(self) -> List[Hyperparameter]: + return sum([forb.get_referenced_hyperparameters() for forb in self.components], []) + def __repr__(self): pass @@ -497,6 +506,9 @@ cdef class ForbiddenRelation(AbstractForbiddenComponent): self.right = right self.vector_ids = (-1, -1) + def get_referenced_hyperparameters(self) -> List[Hyperparameter]: + return [self.left, self.right] + def __eq__(self, other: Any) -> bool: if not isinstance(other, self.__class__): return False diff --git a/ConfigSpace/util.py b/ConfigSpace/util.py index 120bfce3..87622de1 100644 --- a/ConfigSpace/util.py +++ b/ConfigSpace/util.py @@ -708,3 +708,68 @@ def get_cartesian_product(value_sets: list[tuple], hp_names: list[str]) -> list[ unchecked_grid_pts.popleft() return checked_grid_pts + + +def remove_hyperparameter(name: str, configuration_space: ConfigurationSpace) -> ConfigurationSpace: + """ + Returns a new configuration space with the hyperparameter removed. + + Parameters + ---------- + name: str + Name of the hyperparameter to remove + + configuration_space: :class:`~ConfigSpace.configuration_space.ConfigurationSpace` + Configuration space from which to remove the hyperparameter. + + Returns + ------- + :class:`~ConfigSpace.configuration_space.Configuration` + A new configuration space without the hyperparameter + """ + if name not in configuration_space._hyperparameters: + raise ValueError(f"{name} not in {configuration_space}") + + hp_to_remove = configuration_space.get_hyperparameter(name) + hps = [ + copy(hp) # type: ignore + for hp in configuration_space.get_hyperparameters() + if hp.name != name + ] + + conditions = [ + cond + for cond in configuration_space.get_conditions() + if hp_to_remove not in cond.get_referenced_hyperparameters() + ] + forbiddens = [ + forbidden + for forbidden in configuration_space.get_forbiddens() + if hp_to_remove not in forbidden.get_referenced_hyperparameters() + ] + + if isinstance(configuration_space.random, np.random.RandomState): + new_seed = configuration_space.random.randint(2**32 - 1) + else: + new_seed = copy(configuration_space.random) # type: ignore + + new_space = ConfigurationSpace( + seed=new_seed, + name=copy(configuration_space.name), # type: ignore + meta=copy(configuration_space.meta), # type: ignore + ) + new_space.add_hyperparameters(hps) + + new_conditions = ConfigurationSpace.substitute_hyperparameters_in_conditions( + conditions=conditions, + new_configspace=new_space, + ) + new_forbiddens = ConfigurationSpace.substitute_hyperparameters_in_forbiddens( + forbiddens=forbiddens, + new_configspace=new_space, + ) + + new_space.add_conditions(new_conditions) + new_space.add_forbidden_clauses(new_forbiddens) + + return new_space From 25c04194ba9ea23d2d17e5d54c20bca94297bd47 Mon Sep 17 00:00:00 2001 From: herilalaina Date: Fri, 7 Jul 2023 13:47:43 +0200 Subject: [PATCH 2/2] add tests --- ConfigSpace/util.py | 22 +++++++++++--------- test/test_util.py | 50 +++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 62 insertions(+), 10 deletions(-) diff --git a/ConfigSpace/util.py b/ConfigSpace/util.py index 87622de1..7166c67a 100644 --- a/ConfigSpace/util.py +++ b/ConfigSpace/util.py @@ -699,7 +699,7 @@ def get_cartesian_product(value_sets: list[tuple], hp_names: list[str]) -> list[ if len(new_active_hp_names) <= 0: raise RuntimeError( "Unexpected error: There should have been a newly activated hyperparameter" - f" for the current configuration values: {str(unchecked_grid_pts[0])}. " + f" for the current configuration values: {unchecked_grid_pts[0]!s}. " "Please contact the developers with the code you ran and the stack trace.", ) from None @@ -730,9 +730,16 @@ def remove_hyperparameter(name: str, configuration_space: ConfigurationSpace) -> if name not in configuration_space._hyperparameters: raise ValueError(f"{name} not in {configuration_space}") + # First, delete children hyperparameters + for child in configuration_space._children[name]: # type: ignore + configuration_space = remove_hyperparameter( + name=child, + configuration_space=configuration_space, + ) + hp_to_remove = configuration_space.get_hyperparameter(name) hps = [ - copy(hp) # type: ignore + copy.deepcopy(hp) # type: ignore for hp in configuration_space.get_hyperparameters() if hp.name != name ] @@ -748,17 +755,12 @@ def remove_hyperparameter(name: str, configuration_space: ConfigurationSpace) -> if hp_to_remove not in forbidden.get_referenced_hyperparameters() ] - if isinstance(configuration_space.random, np.random.RandomState): - new_seed = configuration_space.random.randint(2**32 - 1) - else: - new_seed = copy(configuration_space.random) # type: ignore - new_space = ConfigurationSpace( - seed=new_seed, - name=copy(configuration_space.name), # type: ignore - meta=copy(configuration_space.meta), # type: ignore + name=copy.deepcopy(configuration_space.name), # type: ignore + meta=copy.deepcopy(configuration_space.meta), # type: ignore ) new_space.add_hyperparameters(hps) + new_space.random.set_state(configuration_space.random.get_state()) new_conditions = ConfigurationSpace.substitute_hyperparameters_in_conditions( conditions=conditions, diff --git a/test/test_util.py b/test/test_util.py index 3cd6bfdc..b8a32c18 100644 --- a/test/test_util.py +++ b/test/test_util.py @@ -58,6 +58,7 @@ get_one_exchange_neighbourhood, get_random_neighbor, impute_inactive_values, + remove_hyperparameter, ) @@ -625,3 +626,52 @@ def test_generate_grid(self): assert dict(generated_grid[1]) == {"cat1": "F", "ord1": "2"} assert dict(generated_grid[2]) == {"cat1": "T", "ord1": "1", "int1": 0} assert dict(generated_grid[-1]) == {"cat1": "T", "ord1": "3", "int1": 1000} + + def test_remove_hyperparameter(self): + """Test removing hyperparameter.""" + cs = ConfigurationSpace(seed=1234) + + list_hps = ["cat1", "const1", "float1", "int1", "ord1"] + cat1 = CategoricalHyperparameter(name="cat1", choices=["T", "F"]) + const1 = Constant(name="const1", value=4) + float1 = UniformFloatHyperparameter(name="float1", lower=-1, upper=1, log=False) + int1 = UniformIntegerHyperparameter(name="int1", lower=10, upper=100, log=True) + ord1 = OrdinalHyperparameter(name="ord1", sequence=["1", "2", "3"]) + + cs.add_hyperparameters([float1, int1, cat1, ord1, const1]) + + # test exception if hyperparamter is not in the configuration space + with self.assertRaises(ValueError): + remove_hyperparameter(name="cat", configuration_space=cs) + + for hp_to_remove in list_hps: + cs1 = remove_hyperparameter(name=hp_to_remove, configuration_space=cs) + + # verify that the hyperparameter is not in the configuration space anymore + assert hp_to_remove not in cs1._hyperparameters + + # the other hyperparameters remain in the configuration space + remaining_hps = [hp for hp in list_hps if hp != hp_to_remove] + for hp_name in remaining_hps: + assert hp_name in cs1._hyperparameters + + cs = ConfigurationSpace(seed=1234) + cs.add_hyperparameters([float1, int1, cat1, ord1, const1]) + cs.add_condition(EqualsCondition(int1, cat1, "T")) # int1 only active if cat1 == T + cs.add_forbidden_clause( + ForbiddenAndConjunction( # Forbid ord1 == 3 if cat1 == F + ForbiddenEqualsClause(cat1, "F"), + ForbiddenEqualsClause(ord1, "3"), + ), + ) + assert len(cs.get_conditions()) == 1 + assert len(cs.get_forbiddens()) == 1 + + # check that children hyperparameters are also removed + cs1 = remove_hyperparameter(name="cat1", configuration_space=cs) + assert "cat1" not in cs1._hyperparameters + assert "int1" not in cs1._hyperparameters + + # check that referenced conditions and clauses are also removed + assert len(cs1.get_conditions()) == 0 + assert len(cs1.get_forbiddens()) == 0