An implementation to support complex conditions such as 'a < b and a * b < 10 and ...' #277

Open
jhj0411jhj opened this issue Nov 15, 2022 · 3 comments

Comments

@jhj0411jhj

I implemented a ConditionedConfigurationSpace that supports complex conditions between hyperparameters (e.g., x1 <= x2 and x1 * x2 < 100). Users can define a sample_condition function to restrict which configurations are generated.

The following functions are guaranteed to return valid configurations:

  • sample_configuration()
  • get_one_exchange_neighbourhood() (may return an empty list)
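
To illustrate why the neighbourhood can come back empty, here is a minimal, library-free sketch. It assumes a hypothetical grid of discrete values per hyperparameter and a hypothetical helper `one_exchange_neighbours`; it is not ConfigSpace's actual neighbourhood algorithm. Each neighbour differs from the input in exactly one hyperparameter, and, assuming neighbours that violate the condition are simply discarded, a tightly constrained point can end up with none.

```python
from typing import Callable, Dict, List

Config = Dict[str, int]


def one_exchange_neighbours(
    config: Config,
    grid: Dict[str, List[int]],
    condition: Callable[[Config], bool],
) -> List[Config]:
    """Hypothetical helper: configs differing from `config` in exactly one
    hyperparameter that also satisfy `condition`."""
    neighbours = []
    for name, values in grid.items():
        for value in values:
            if value == config[name]:
                continue  # skip the unchanged value
            candidate = dict(config)
            candidate[name] = value
            if condition(candidate):  # drop neighbours that break the condition
                neighbours.append(candidate)
    return neighbours


def sample_condition(config: Config) -> bool:
    # require x1 <= x2 and x1 * x2 < 100
    return config['x1'] <= config['x2'] and config['x1'] * config['x2'] < 100


grid = {'x1': [1, 5, 9], 'x2': [9, 10, 20]}
print(one_exchange_neighbours({'x1': 9, 'x2': 9}, grid, sample_condition))
# prints [{'x1': 1, 'x2': 9}, {'x1': 5, 'x2': 9}, {'x1': 9, 'x2': 10}]

tight_grid = {'x1': [1, 50], 'x2': [1, 200]}
print(one_exchange_neighbours({'x1': 1, 'x2': 1}, tight_grid, sample_condition))
# prints [] because every single-value change violates the condition
```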

Here is an example:

def sample_condition(config):
    # require x1 <= x2 and x1 * x2 < 100
    if config['x1'] > config['x2']:
        return False
    if config['x1'] * config['x2'] >= 100:
        return False
    return True
    # return config['x1'] <= config['x2'] and config['x1'] * config['x2'] < 100

cs = ConditionedConfigurationSpace()
cs.add_hyperparameters([...])
cs.set_sample_condition(sample_condition)  # set the sample condition after all hyperparameters are added
configs = cs.sample_configuration(1000)
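
The idea behind sample_configuration() here is rejection sampling: draw candidates from the unconstrained space and keep only those that sample_condition accepts. A minimal stand-alone sketch using only the standard library, assuming two uniform integer hyperparameters x1 and x2 in [1, 50] (the ranges and the helper names `sample_unconstrained` and `rejection_sample` are illustrative, not part of ConfigSpace):

```python
import random
from typing import Dict, List


def sample_condition(config: Dict[str, int]) -> bool:
    # require x1 <= x2 and x1 * x2 < 100
    return config['x1'] <= config['x2'] and config['x1'] * config['x2'] < 100


def sample_unconstrained(rng: random.Random) -> Dict[str, int]:
    # Hypothetical stand-in for ConfigurationSpace.sample_configuration()
    return {'x1': rng.randint(1, 50), 'x2': rng.randint(1, 50)}


def rejection_sample(size: int, seed: int = 0) -> List[Dict[str, int]]:
    rng = random.Random(seed)
    accepted = []
    while len(accepted) < size:
        candidate = sample_unconstrained(rng)
        if sample_condition(candidate):  # reject candidates that break the condition
            accepted.append(candidate)
    return accepted


configs = rejection_sample(1000)
```

With these ranges only 192 of the 2,500 possible (x1, x2) pairs satisfy the condition (about 8%), so roughly 13 candidates are drawn per accepted configuration. A tight condition therefore makes sampling slow, and an unsatisfiable one would loop forever without a retry limit.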

Implementing this feature with forbiddens such as ForbiddenClause or ForbiddenRelation might be a viable option, but it is somewhat complicated, and the user may need to pass the full configuration space to the forbidden object.

The implementation does not handle serialization of the sample_condition function.
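
Because sample_condition is an arbitrary callable, it cannot be written to JSON along with the rest of the space. One possible workaround, sketched under the assumption that conditions are module-level functions registered under a stable name (the `CONDITIONS` registry and the `register_condition`, `dump_condition` and `load_condition` helpers are hypothetical, not ConfigSpace API), is to store only the name and look the function up again on load:

```python
import json
from typing import Callable, Dict

# Hypothetical registry mapping a stable name to a condition function
CONDITIONS: Dict[str, Callable[[dict], bool]] = {}


def register_condition(name: str):
    def decorator(func):
        CONDITIONS[name] = func
        return func
    return decorator


@register_condition('x1_le_x2_and_product_lt_100')
def sample_condition(config: dict) -> bool:
    return config['x1'] <= config['x2'] and config['x1'] * config['x2'] < 100


def dump_condition(func) -> str:
    # Store only the registered name, never the function object itself
    for name, registered in CONDITIONS.items():
        if registered is func:
            return json.dumps({'sample_condition': name})
    raise ValueError('condition is not registered')


def load_condition(payload: str):
    name = json.loads(payload)['sample_condition']
    return CONDITIONS[name]  # raises KeyError for an unknown name


payload = dump_condition(sample_condition)
restored = load_condition(payload)
```

The trade-off is that the code defining the condition must be importable wherever the space is loaded; the JSON alone does not carry the logic.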

Here is the implementation of ConditionedConfigurationSpace:

from typing import List, Union, Callable
import numpy as np
from ConfigSpace import ConfigurationSpace, Configuration
from ConfigSpace.exceptions import ForbiddenValueError


class ConditionedConfigurationSpace(ConfigurationSpace):
    """
    A configuration space that supports complex conditions between hyperparameters,
        e.g., x1 <= x2 and x1 * x2 < 100.

    Users can define a sample_condition function to restrict the generation of configurations.

    The following functions are guaranteed to return valid configurations:
        - self.sample_configuration()
        - get_one_exchange_neighbourhood()  # may return empty list

    Example
    -------

    >>> def sample_condition(config):
    ...     # require x1 <= x2 and x1 * x2 < 100
    ...     if config['x1'] > config['x2']:
    ...         return False
    ...     if config['x1'] * config['x2'] >= 100:
    ...         return False
    ...     return True
    ...
    >>> cs = ConditionedConfigurationSpace()
    >>> cs.add_hyperparameters([...])
    >>> cs.set_sample_condition(sample_condition)  # set the sample condition after all hyperparameters are added
    >>> configs = cs.sample_configuration(1000)

    Author: Jhj

    """
    sample_condition: Callable[[Configuration], bool] = None

    def set_sample_condition(self, sample_condition: Callable[[Configuration], bool]):
        """
        The sample_condition function takes a configuration as input and returns a boolean value.
            - If the return value is True, the configuration is valid and will be sampled.
            - If the return value is False, the configuration is invalid and will be rejected.
        This function should be called after all hyperparameters are added to the conditioned space.
        """
        self.sample_condition = sample_condition
        self._check_default_configuration()

    def _check_forbidden(self, vector: np.ndarray) -> None:
        """
        This function is called in Configuration.is_valid_configuration().
            - When Configuration.__init__() is called with values (dict), is_valid_configuration() is called.
            - When Configuration.__init__() is called with vectors (np.ndarray), there will be no validation check.
        This function is also called in get_one_exchange_neighbourhood().
        """
        # check original forbidden clauses first
        super()._check_forbidden(vector)

        if self.sample_condition is not None:
            # Populating a configuration from an array does not check if it is a legal configuration.
            # _check_forbidden() is not called. Otherwise, this would be stuck in an infinite loop.
            config = Configuration(self, vector=vector)
            if not self.sample_condition(config):
                raise ForbiddenValueError('User-defined sample condition is not satisfied.')

    def sample_configuration(self, size: int = 1) -> Union['Configuration', List['Configuration']]:
        """
        In ConfigurationSpace.sample_configuration, configurations are built with vectors (np.ndarray),
            so there will be no validation check and _check_forbidden() will not be called.
            We need to check the sample condition manually.

        Returns a single configuration if size = 1 else a list of Configurations
        """
        if self.sample_condition is None:
            return super().sample_configuration(size=size)

        if not isinstance(size, int):
            raise TypeError('Argument size must be of type int, but is %s'
                            % type(size))
        elif size < 1:
            return []

        error_iteration = 0
        accepted_configurations = []  # type: List['Configuration']
        while len(accepted_configurations) < size:
            missing = size - len(accepted_configurations)

            if missing != size:
                missing = int(1.1 * missing)
            missing += 2

            configurations = super().sample_configuration(size=missing)  # missing > 1, return a list
            configurations = [c for c in configurations if self.sample_condition(c)]
            if len(configurations) > 0:
                accepted_configurations.extend(configurations)
            else:
                error_iteration += 1
                if error_iteration > 1000:
                    raise ForbiddenValueError("Cannot sample valid configuration for %s" % self)

        if size <= 1:
            return accepted_configurations[0]
        else:
            return accepted_configurations[:size]

    def add_hyperparameter(self, *args, **kwargs):
        if self.sample_condition is not None:
            raise ValueError('Please add hyperparameter before setting sample condition.')
        return super().add_hyperparameter(*args, **kwargs)

    def add_hyperparameters(self, *args, **kwargs):
        if self.sample_condition is not None:
            raise ValueError('Please add hyperparameters before setting sample condition.')
        return super().add_hyperparameters(*args, **kwargs)
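
The batching in sample_configuration() requests size + 2 candidates in the first round and, in later rounds, about 10% more than are still missing plus two, so most calls finish in one or two rounds; it gives up after more than 1000 rounds that yield no valid configuration. The same control flow can be sketched without ConfigSpace. Here `draw` is a hypothetical function returning `n` unconstrained candidates, and `RuntimeError` stands in for `ForbiddenValueError`:

```python
import random
from typing import Callable, List, TypeVar

T = TypeVar('T')


def sample_with_condition(
    draw: Callable[[int], List[T]],
    condition: Callable[[T], bool],
    size: int,
    max_failed_rounds: int = 1000,
) -> List[T]:
    """Batched rejection sampling mirroring the loop in sample_configuration()."""
    accepted: List[T] = []
    failed_rounds = 0  # total rounds that yielded no valid candidate
    while len(accepted) < size:
        missing = size - len(accepted)
        if missing != size:
            missing = int(1.1 * missing)  # oversample ~10% after the first round
        missing += 2
        batch = [c for c in draw(missing) if condition(c)]
        if batch:
            accepted.extend(batch)
        else:
            failed_rounds += 1
            if failed_rounds > max_failed_rounds:
                raise RuntimeError('cannot sample a valid configuration')
    return accepted[:size]


rng = random.Random(0)


def draw_ints(n: int) -> List[int]:
    # Hypothetical unconstrained sampler: n integers in [0, 99]
    return [rng.randint(0, 99) for _ in range(n)]


def is_even(x: int) -> bool:
    return x % 2 == 0


def is_negative(x: int) -> bool:
    return x < 0  # never satisfiable for draw_ints


evens = sample_with_condition(draw_ints, is_even, size=10)

try:
    sample_with_condition(draw_ints, is_negative, size=1, max_failed_rounds=5)
    gave_up = False
except RuntimeError:
    gave_up = True
```

Note that, as in the original loop, the failure counter is never reset, so it counts all empty rounds rather than consecutive ones.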

@eddiebergman
Contributor

Hi @jhj0411jhj,

If you're adding code in the future, I highly recommend a PR. I think the scope of this is much more in line with the linked PR #280, which adds callables to forbidden relations. We also would not want to add a new class just to enable one feature, so this would have to be integrated into the main code itself.

Perhaps you can comment on the linked PR itself and see about getting this integrated there!

@jhj0411jhj
Author

The intention of this issue is to provide an extension for others without modifying the ConfigSpace code, which is written in Cython and is somewhat hard for users to modify without downloading the source code. I didn't create a PR because I'm not very familiar with Cython, and I was afraid of breaking the code.

My implementation is more a stopgap than a perfect solution. Implementing this as a forbidden is more compatible with the current APIs, but might be complicated. I'd be happy to see such a feature supported in ConfigSpace.

@eddiebergman
Contributor

Okay, I wasn't fully sure of your intentions, but thanks for making them clear! I'll leave this open in the meantime, but I guess there's not much that needs to be followed up on from our side for now. Thanks for the snippet!
