Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove hyperparameter in a configuration space #332

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions ConfigSpace/conditions.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,9 @@ cdef class ConditionComponent(object):
def evaluate_vector(self, instantiated_vector):
return bool(self._evaluate_vector(instantiated_vector))

def get_referenced_hyperparameters(self) -> List[Hyperparameter]:
pass

cdef int _evaluate_vector(self, np.ndarray value):
pass

Expand Down Expand Up @@ -147,6 +150,9 @@ cdef class AbstractCondition(ConditionComponent):
hp_name = self.parent.name
return self._evaluate(instantiated_parent_hyperparameter[hp_name])

def get_referenced_hyperparameters(self) -> List[Hyperparameter]:
return self.get_children() + self.get_parents()

cdef int _evaluate_vector(self, np.ndarray instantiated_vector):
if self.parent_vector_id is None:
raise ValueError("Parent vector id should not be None when calling evaluate vector")
Expand Down Expand Up @@ -523,6 +529,9 @@ cdef class AbstractConjunction(ConditionComponent):
raise ValueError("All Conjunctions and Conditions must have "
"the same child.")

def get_referenced_hyperparameters(self) -> List[Hyperparameter]:
return sum([cond.get_referenced_hyperparameters() for cond in self.components], [])
herilalaina marked this conversation as resolved.
Show resolved Hide resolved

def __eq__(self, other: Any) -> bool:
"""
This method implements a comparison between self and another
Expand Down
14 changes: 13 additions & 1 deletion ConfigSpace/forbidden.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ import numpy as np
import io
from ConfigSpace.hyperparameters import Hyperparameter
from ConfigSpace.hyperparameters.hyperparameter cimport Hyperparameter
from typing import Dict, Any, Union
from typing import Dict, Any, Union, List

from ConfigSpace.forbidden cimport AbstractForbiddenComponent

Expand Down Expand Up @@ -71,6 +71,9 @@ cdef class AbstractForbiddenComponent(object):
return (self.value == other.value and
self.hyperparameter.name == other.hyperparameter.name)

def get_referenced_hyperparameters(self) -> List[Hyperparameter]:
pass

def __hash__(self) -> int:
"""Override the default hash behavior (that returns the id or the object)"""
return hash(tuple(sorted(self.__dict__.items())))
Expand Down Expand Up @@ -103,6 +106,9 @@ cdef class AbstractForbiddenClause(AbstractForbiddenComponent):
self.hyperparameter = hyperparameter
self.vector_id = -1

def get_referenced_hyperparameters(self) -> List[Hyperparameter]:
return [self.hyperparameter]

cpdef get_descendant_literal_clauses(self):
return (self, )

Expand Down Expand Up @@ -315,6 +321,9 @@ cdef class AbstractForbiddenConjunction(AbstractForbiddenComponent):
self.n_components = len(self.components)
self.dlcs = self.get_descendant_literal_clauses()

def get_referenced_hyperparameters(self) -> List[Hyperparameter]:
return sum([forb.get_referenced_hyperparameters() for forb in self.components], [])

def __repr__(self):
pass

Expand Down Expand Up @@ -497,6 +506,9 @@ cdef class ForbiddenRelation(AbstractForbiddenComponent):
self.right = right
self.vector_ids = (-1, -1)

def get_referenced_hyperparameters(self) -> List[Hyperparameter]:
return [self.left, self.right]

def __eq__(self, other: Any) -> bool:
if not isinstance(other, self.__class__):
return False
Expand Down
65 changes: 65 additions & 0 deletions ConfigSpace/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -708,3 +708,68 @@ def get_cartesian_product(value_sets: list[tuple], hp_names: list[str]) -> list[
unchecked_grid_pts.popleft()

return checked_grid_pts


def remove_hyperparameter(name: str, configuration_space: ConfigurationSpace) -> ConfigurationSpace:
"""
Returns a new configuration space with the hyperparameter removed.

Parameters
----------
name: str
Name of the hyperparameter to remove

configuration_space: :class:`~ConfigSpace.configuration_space.ConfigurationSpace`
Configuration space from which to remove the hyperparameter.

Returns
-------
:class:`~ConfigSpace.configuration_space.Configuration`
A new configuration space without the hyperparameter
"""
if name not in configuration_space._hyperparameters:
raise ValueError(f"{name} not in {configuration_space}")

hp_to_remove = configuration_space.get_hyperparameter(name)
hps = [
copy(hp) # type: ignore
for hp in configuration_space.get_hyperparameters()
if hp.name != name
]

conditions = [
cond
for cond in configuration_space.get_conditions()
if hp_to_remove not in cond.get_referenced_hyperparameters()
]
forbiddens = [
forbidden
for forbidden in configuration_space.get_forbiddens()
if hp_to_remove not in forbidden.get_referenced_hyperparameters()
]

if isinstance(configuration_space.random, np.random.RandomState):
new_seed = configuration_space.random.randint(2**32 - 1)
herilalaina marked this conversation as resolved.
Show resolved Hide resolved
else:
new_seed = copy(configuration_space.random) # type: ignore

new_space = ConfigurationSpace(
seed=new_seed,
name=copy(configuration_space.name), # type: ignore
meta=copy(configuration_space.meta), # type: ignore
)
new_space.add_hyperparameters(hps)

new_conditions = ConfigurationSpace.substitute_hyperparameters_in_conditions(
conditions=conditions,
new_configspace=new_space,
)
new_forbiddens = ConfigurationSpace.substitute_hyperparameters_in_forbiddens(
forbiddens=forbiddens,
new_configspace=new_space,
)

new_space.add_conditions(new_conditions)
new_space.add_forbidden_clauses(new_forbiddens)

return new_space
Loading