Skip to content

Commit

Permalink
Fix small bugs related to nan values in vector passed to pdf (#256)
Browse files Browse the repository at this point in the history
* Enable replacing InCondition and ForbiddenRelation constraints

* Allow nan values in CategoricalHP _pdf function
  • Loading branch information
Marc authored Jan 30, 2023
1 parent 88051eb commit 05ab3da
Show file tree
Hide file tree
Showing 4 changed files with 65 additions and 12 deletions.
25 changes: 17 additions & 8 deletions ConfigSpace/configuration_space.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -1526,12 +1526,12 @@ class ConfigurationSpace(collections.abc.Mapping):
new_child = new_configspace[child_name]
new_parent = new_configspace[parent_name]

if hasattr(condition, 'value'):
condition_arg = getattr(condition, 'value')
substituted_condition = condition_type(child=new_child, parent=new_parent, value=condition_arg)
elif hasattr(condition, 'values'):
if hasattr(condition, 'values'):
condition_arg = getattr(condition, 'values')
substituted_condition = condition_type(child=new_child, parent=new_parent, values=condition_arg)
elif hasattr(condition, 'value'):
condition_arg = getattr(condition, 'value')
substituted_condition = condition_type(child=new_child, parent=new_parent, value=condition_arg)
else:
raise AttributeError(f'Did not find the expected attribute in condition {type(condition)}.')

Expand Down Expand Up @@ -1573,15 +1573,24 @@ class ConfigurationSpace(collections.abc.Mapping):
hyperparameter_name = getattr(forbidden.hyperparameter, 'name')
new_hyperparameter = new_configspace[hyperparameter_name]

if hasattr(forbidden, 'value'):
forbidden_arg = getattr(forbidden, 'value')
substituted_forbidden = forbidden_type(hyperparameter=new_hyperparameter, value=forbidden_arg)
elif hasattr(forbidden, 'values'):
if hasattr(forbidden, 'values'):
forbidden_arg = getattr(forbidden, 'values')
substituted_forbidden = forbidden_type(hyperparameter=new_hyperparameter, values=forbidden_arg)
elif hasattr(forbidden, 'value'):
forbidden_arg = getattr(forbidden, 'value')
substituted_forbidden = forbidden_type(hyperparameter=new_hyperparameter, value=forbidden_arg)
else:
raise AttributeError(f'Did not find the expected attribute in forbidden {type(forbidden)}.')

new_forbiddens.append(substituted_forbidden)
elif isinstance(forbidden, ForbiddenRelation):
forbidden_type = type(forbidden)
left_name = getattr(forbidden.left, 'name')
left_hyperparameter = new_configspace[left_name]
right_name = getattr(forbidden.right, 'name')
right_hyperparameter = new_configspace[right_name]

substituted_forbidden = forbidden_type(left=left_hyperparameter, right=right_hyperparameter)
new_forbiddens.append(substituted_forbidden)
else:
raise TypeError(f'Did not expect the supplied forbidden type {type(forbidden)}.')
Expand Down
6 changes: 6 additions & 0 deletions ConfigSpace/hyperparameters.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -2609,7 +2609,13 @@ cdef class CategoricalHyperparameter(Hyperparameter):
Probability density values of the input vector
"""
probs = np.array(self.probabilities)
nan = np.isnan(vector)
if np.any(nan):
# Temporarily pick any valid index to use `vector` as an index for `probs`
vector[nan] = 0
res = np.array(probs[vector.astype(int)])
if np.any(nan):
res[nan] = 0
if res.ndim == 0:
return res.reshape(-1)
return res
Expand Down
37 changes: 34 additions & 3 deletions test/test_configuration_space.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@
BetaIntegerHyperparameter,
OrdinalHyperparameter)
from ConfigSpace.exceptions import ForbiddenValueError
from ConfigSpace.forbidden import ForbiddenEqualsRelation
from ConfigSpace.forbidden import ForbiddenEqualsRelation, ForbiddenLessThanRelation


def byteify(input):
Expand Down Expand Up @@ -919,6 +919,34 @@ def test_substitute_hyperparameters_in_conditions(self):
self.assertEqual(new_conditions[0], test_conditions[0])
self.assertEqual(new_conditions[1], test_conditions[1])

def test_substitute_hyperparameters_in_inconditions(self):
cs1 = ConfigurationSpace()
a = UniformIntegerHyperparameter('a', lower=0, upper=10)
b = UniformFloatHyperparameter('b', lower=1., upper=8., log=False)
cs1.add_hyperparameters([a, b])

cond = InCondition(b, a, [1, 2, 3, 4])
cs1.add_conditions([cond])

cs2 = ConfigurationSpace()
sub_a = UniformIntegerHyperparameter('a', lower=0, upper=10)
sub_b = UniformFloatHyperparameter('b', lower=1., upper=8., log=False)
cs2.add_hyperparameters([sub_a, sub_b])
new_conditions = cs1.substitute_hyperparameters_in_conditions(cs1.get_conditions(), cs2)

test_cond = InCondition(b, a, [1, 2, 3, 4])
cs2.add_conditions([test_cond])
test_conditions = cs2.get_conditions()

self.assertEqual(new_conditions[0], test_conditions[0])
self.assertIsNot(new_conditions[0], test_conditions[0])

self.assertEqual(new_conditions[0].get_parents(), test_conditions[0].get_parents())
self.assertIsNot(new_conditions[0].get_parents(), test_conditions[0].get_parents())

self.assertEqual(new_conditions[0].get_children(), test_conditions[0].get_children())
self.assertIsNot(new_conditions[0].get_children(), test_conditions[0].get_children())

def test_substitute_hyperparameters_in_forbiddens(self):
cs1 = ConfigurationSpace()
orig_hp1 = CategoricalHyperparameter("input1", [0, 1])
Expand All @@ -930,7 +958,8 @@ def test_substitute_hyperparameters_in_forbiddens(self):
forb_2 = ForbiddenEqualsClause(orig_hp2, 1)
forb_3 = ForbiddenEqualsClause(orig_hp3, 10)
forb_4 = ForbiddenAndConjunction(forb_1, forb_2)
cs1.add_forbidden_clauses([forb_3, forb_4])
forb_5 = ForbiddenLessThanRelation(orig_hp1, orig_hp2)
cs1.add_forbidden_clauses([forb_3, forb_4, forb_5])

cs2 = ConfigurationSpace()
sub_hp1 = CategoricalHyperparameter("input1", [0, 1, 2])
Expand All @@ -944,9 +973,11 @@ def test_substitute_hyperparameters_in_forbiddens(self):
test_forb_2 = ForbiddenEqualsClause(sub_hp2, 1)
test_forb_3 = ForbiddenEqualsClause(sub_hp3, 10)
test_forb_4 = ForbiddenAndConjunction(test_forb_1, test_forb_2)
cs2.add_forbidden_clauses([test_forb_3, test_forb_4])
test_forb_5 = ForbiddenLessThanRelation(sub_hp1, sub_hp2)
cs2.add_forbidden_clauses([test_forb_3, test_forb_4, test_forb_5])
test_forbiddens = cs2.get_forbiddens()

self.assertEqual(new_forbiddens[2], test_forbiddens[2])
self.assertEqual(new_forbiddens[1], test_forbiddens[1])
self.assertEqual(new_forbiddens[0], test_forbiddens[0])

Expand Down
9 changes: 8 additions & 1 deletion test/test_hyperparameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -1945,6 +1945,7 @@ def test_categorical__pdf(self):
point_1 = np.array([0])
point_2 = np.array([1])
array_1 = np.array([1, 0, 2])
nan = np.array([0, np.nan])
self.assertEqual(c1._pdf(point_1)[0], 0.4)
self.assertEqual(c1._pdf(point_2)[0], 0.2)
self.assertAlmostEqual(c2._pdf(point_1)[0], 0.7142857142857143)
Expand All @@ -1956,14 +1957,20 @@ def test_categorical__pdf(self):
for res, exp_res in zip(array_results, expected_results):
self.assertEqual(res, exp_res)

nan_results = c1._pdf(nan)
expected_results = np.array([0.4, 0])
self.assertEqual(nan_results.shape, expected_results.shape)
for res, exp_res in zip(nan_results, expected_results):
self.assertEqual(res, exp_res)

# pdf must take a numpy array
with self.assertRaises(TypeError):
c1._pdf(0.2)
with self.assertRaises(TypeError):
c1._pdf('pdf')
with self.assertRaises(TypeError):
c1._pdf('one')
with self.assertRaises(ValueError):
with self.assertRaises(TypeError):
c1._pdf(np.array(['zero']))

def test_categorical_get_max_density(self):
Expand Down

0 comments on commit 05ab3da

Please sign in to comment.