48
48
from botorch .sampling .normal import IIDNormalSampler
49
49
from botorch .utils .sampling import manual_seed , unnormalize
50
50
from botorch .utils .testing import (
51
- _get_max_violation_of_bounds ,
52
- _get_max_violation_of_constraints ,
53
51
BotorchTestCase ,
52
+ get_max_violation_of_bounds ,
53
+ get_max_violation_of_constraints ,
54
54
MockAcquisitionFunction ,
55
55
MockModel ,
56
56
MockPosterior ,
@@ -61,11 +61,11 @@ class TestBoundsAndConstraintCheckers(BotorchTestCase):
61
61
def test_bounds_check (self ) -> None :
62
62
bounds = torch .tensor ([[1 , 2 ], [3 , 4 ]], device = self .device )
63
63
samples = torch .tensor ([[2 , 3 ], [2 , 3.1 ]], device = self .device )[None , :, :]
64
- result = _get_max_violation_of_bounds (samples , bounds )
64
+ result = get_max_violation_of_bounds (samples , bounds )
65
65
self .assertAlmostEqual (result , - 0.9 , delta = 1e-6 )
66
66
67
67
samples = torch .tensor ([[2 , 3 ], [2 , 4.1 ]], device = self .device )[None , :, :]
68
- result = _get_max_violation_of_bounds (samples , bounds )
68
+ result = get_max_violation_of_bounds (samples , bounds )
69
69
self .assertAlmostEqual (result , 0.1 , delta = 1e-6 )
70
70
71
71
def test_constraint_check (self ) -> None :
@@ -77,10 +77,10 @@ def test_constraint_check(self) -> None:
77
77
)
78
78
]
79
79
samples = torch .tensor ([[2 , 3 ], [2 , 3.1 ]], device = self .device )[None , :, :]
80
- result = _get_max_violation_of_constraints (samples , constraints , equality = True )
80
+ result = get_max_violation_of_constraints (samples , constraints , equality = True )
81
81
self .assertAlmostEqual (result , 0.1 , delta = 1e-6 )
82
82
83
- result = _get_max_violation_of_constraints (samples , constraints , equality = False )
83
+ result = get_max_violation_of_constraints (samples , constraints , equality = False )
84
84
self .assertAlmostEqual (result , 0.0 , delta = 1e-6 )
85
85
86
86
@@ -268,7 +268,7 @@ def test_gen_batch_initial_conditions(self):
268
268
self .assertEqual (batch_initial_conditions .device , bounds .device )
269
269
self .assertEqual (batch_initial_conditions .dtype , bounds .dtype )
270
270
self .assertLess (
271
- _get_max_violation_of_bounds (batch_initial_conditions , bounds ),
271
+ get_max_violation_of_bounds (batch_initial_conditions , bounds ),
272
272
1e-6 ,
273
273
)
274
274
batch_shape = (
@@ -347,7 +347,7 @@ def test_gen_batch_initial_conditions_topn(self):
347
347
self .assertEqual (batch_initial_conditions .device , bounds .device )
348
348
self .assertEqual (batch_initial_conditions .dtype , bounds .dtype )
349
349
self .assertLess (
350
- _get_max_violation_of_bounds (batch_initial_conditions , bounds ),
350
+ get_max_violation_of_bounds (batch_initial_conditions , bounds ),
351
351
1e-6 ,
352
352
)
353
353
batch_shape = (
@@ -409,7 +409,7 @@ def test_gen_batch_initial_conditions_highdim(self):
409
409
self .assertEqual (batch_initial_conditions .device , bounds .device )
410
410
self .assertEqual (batch_initial_conditions .dtype , bounds .dtype )
411
411
self .assertLess (
412
- _get_max_violation_of_bounds (batch_initial_conditions , bounds ), 1e-6
412
+ get_max_violation_of_bounds (batch_initial_conditions , bounds ), 1e-6
413
413
)
414
414
if ffs is not None :
415
415
for idx , val in ffs .items ():
@@ -637,18 +637,18 @@ def _to_self_device(
637
637
return None if x is None else x .to (device = self .device )
638
638
639
639
self .assertLess (
640
- _get_max_violation_of_bounds (_to_self_device (samples ), bounds ), tol
640
+ get_max_violation_of_bounds (_to_self_device (samples ), bounds ), tol
641
641
)
642
642
643
643
self .assertLess (
644
- _get_max_violation_of_constraints (
644
+ get_max_violation_of_constraints (
645
645
_to_self_device (samples ), constraints = equalities , equality = True
646
646
),
647
647
tol ,
648
648
)
649
649
650
650
self .assertLess (
651
- _get_max_violation_of_constraints (
651
+ get_max_violation_of_constraints (
652
652
_to_self_device (samples ),
653
653
constraints = inequalities ,
654
654
equality = False ,
@@ -708,19 +708,19 @@ def test_gen_batch_initial_conditions_constraints(self):
708
708
self .assertEqual (batch_initial_conditions .device , bounds .device )
709
709
self .assertEqual (batch_initial_conditions .dtype , bounds .dtype )
710
710
self .assertLess (
711
- _get_max_violation_of_bounds (batch_initial_conditions , bounds ),
711
+ get_max_violation_of_bounds (batch_initial_conditions , bounds ),
712
712
1e-6 ,
713
713
)
714
714
self .assertLess (
715
- _get_max_violation_of_constraints (
715
+ get_max_violation_of_constraints (
716
716
batch_initial_conditions ,
717
717
inequality_constraints ,
718
718
equality = False ,
719
719
),
720
720
1e-6 ,
721
721
)
722
722
self .assertLess (
723
- _get_max_violation_of_constraints (
723
+ get_max_violation_of_constraints (
724
724
batch_initial_conditions ,
725
725
equality_constraints ,
726
726
equality = True ,
@@ -821,7 +821,7 @@ def test_gen_batch_initial_conditions_interpoint_constraints(self):
821
821
batch_initial_conditions [1 , 2 , 0 ],
822
822
)
823
823
self .assertLess (
824
- _get_max_violation_of_constraints (
824
+ get_max_violation_of_constraints (
825
825
batch_initial_conditions ,
826
826
inequality_constraints ,
827
827
equality = False ,
@@ -886,7 +886,7 @@ def generator(n: int, q: int, seed: int | None):
886
886
self .assertEqual (batch_initial_conditions .dtype , bounds .dtype )
887
887
self .assertTrue ((batch_initial_conditions [..., - 1 ] == 0.42 ).all ())
888
888
self .assertLess (
889
- _get_max_violation_of_bounds (batch_initial_conditions , bounds ),
889
+ get_max_violation_of_bounds (batch_initial_conditions , bounds ),
890
890
1e-6 ,
891
891
)
892
892
if ffs is not None :
@@ -981,7 +981,7 @@ def test_gen_batch_initial_conditions_fixed_X_fantasies(self):
981
981
self .assertEqual (batch_initial_conditions .device , bounds .device )
982
982
self .assertEqual (batch_initial_conditions .dtype , bounds .dtype )
983
983
self .assertLess (
984
- _get_max_violation_of_bounds (batch_initial_conditions , bounds ),
984
+ get_max_violation_of_bounds (batch_initial_conditions , bounds ),
985
985
1e-6 ,
986
986
)
987
987
batch_shape = (
0 commit comments