diff --git a/docs/scripts/api_generator.py b/docs/scripts/api_generator.py index 038ed9dd..62a9048c 100644 --- a/docs/scripts/api_generator.py +++ b/docs/scripts/api_generator.py @@ -2,6 +2,7 @@ # https://mkdocstrings.github.io/recipes/ """ + from __future__ import annotations import logging diff --git a/docs/scripts/debug_which_page_is_being_rendered.py b/docs/scripts/debug_which_page_is_being_rendered.py index 5f8b642f..efbe40ed 100644 --- a/docs/scripts/debug_which_page_is_being_rendered.py +++ b/docs/scripts/debug_which_page_is_being_rendered.py @@ -3,6 +3,7 @@ This makes it easier to identify which file is being rendered when an error happens. """ + from __future__ import annotations import logging @@ -16,6 +17,7 @@ log = logging.getLogger("mkdocs") + def on_pre_page( page: mkdocs.structure.pages.Page, config: Any, diff --git a/src/ConfigSpace/forbidden.py b/src/ConfigSpace/forbidden.py index 4b6ba0d4..5eb535db 100644 --- a/src/ConfigSpace/forbidden.py +++ b/src/ConfigSpace/forbidden.py @@ -614,13 +614,18 @@ def is_forbidden_vector(self, vector: Array[f64]) -> bool: # Relation is always evaluated against actual value and not vector rep left: f64 = vector[self.vector_ids[0]] # type: ignore right: f64 = vector[self.vector_ids[1]] # type: ignore + if np.isnan(left) or np.isnan(right): + return False return self.left.to_value(left) < self.right.to_value(right) # type: ignore @override def is_forbidden_vector_array(self, arr: Array[f64]) -> Mask: left = arr[self.vector_ids[0]] right = arr[self.vector_ids[1]] - return self.left.to_value(left) < self.right.to_value(right) + valid = ~(np.isnan(left) | np.isnan(right)) + out = np.zeros_like(valid) + out[valid] = self.left.to_value(left[valid]) < self.right.to_value(right[valid]) + return out class ForbiddenEqualsRelation(ForbiddenRelation): @@ -680,13 +685,19 @@ def is_forbidden_vector(self, vector: Array[f64]) -> bool: # Relation is always evaluated against actual value and not vector rep left = vector[self.vector_ids[0]] right = vector[self.vector_ids[1]] + if np.isnan(left) or np.isnan(right): + return False return self.left.to_value(left) == self.right.to_value(right) # type: ignore @override def is_forbidden_vector_array(self, arr: Array[f64]) -> Mask: left = arr[self.vector_ids[0]] right = arr[self.vector_ids[1]] - return self.left.to_value(left) == self.right.to_value(right) # type: ignore + valid = ~(np.isnan(left) | np.isnan(right)) + tmp = self.left.to_value(left[valid]) == self.right.to_value(right[valid]) + out = np.zeros_like(valid) + out[valid] = tmp + return out # type: ignore class ForbiddenGreaterThanRelation(ForbiddenRelation): @@ -745,13 +756,18 @@ def is_forbidden_vector(self, vector: Array[f64]) -> bool: # Relation is always evaluated against actual value and not vector rep left: f64 = vector[self.vector_ids[0]] # type: ignore right: f64 = vector[self.vector_ids[1]] # type: ignore + if np.isnan(left) or np.isnan(right): + return False return self.left.to_value(left) > self.right.to_value(right) # type: ignore @override def is_forbidden_vector_array(self, arr: Array[f64]) -> Mask: left = arr[self.vector_ids[0]] right = arr[self.vector_ids[1]] - return self.left.to_value(left) > self.right.to_value(right) + valid = ~(np.isnan(left) | np.isnan(right)) + out = np.zeros_like(valid) + out[valid] = self.left.to_value(left[valid]) > self.right.to_value(right[valid]) + return out ForbiddenLike = Union[ diff --git a/test/test_forbidden.py b/test/test_forbidden.py index f297a6a8..7f47b881 100644 --- a/test/test_forbidden.py +++ b/test/test_forbidden.py @@ -291,3 +291,23 @@ def test_relation(): assert forb.is_forbidden_value( {"water_temperature": "hot", "water_temperature2": "cold"}, ) + + +def test_relation_conditioned(): + from ConfigSpace import EqualsCondition, ConfigurationSpace + + a = OrdinalHyperparameter("a", [2, 5, 10]) + enable_a = CategoricalHyperparameter("enable_a", [False, True], weights=[99999, 1]) + cond_a = EqualsCondition(a, enable_a, True) + + b = OrdinalHyperparameter("b", [5, 10, 15]) + for forbid in ( + ForbiddenEqualsRelation, + ForbiddenGreaterThanRelation, + ForbiddenLessThanRelation, + ): + forbid_a_b = forbid(a, b) + + cs = ConfigurationSpace() + cs.add([a, enable_a, cond_a, b, forbid_a_b]) + cs.sample_configuration(100)