6
6
7
7
# pyre-strict
8
8
9
- import warnings
10
-
11
9
from ax .core .metric import Metric
12
10
from ax .core .objective import MultiObjective , Objective , ScalarizedObjective
11
+ from ax .exceptions .core import UserInputError
13
12
from ax .utils .common .testutils import TestCase
14
13
15
14
@@ -21,7 +20,7 @@ def setUp(self) -> None:
21
20
"m3" : Metric (name = "m3" , lower_is_better = False ),
22
21
}
23
22
self .objectives = {
24
- "o1" : Objective (metric = self .metrics ["m1" ]),
23
+ "o1" : Objective (metric = self .metrics ["m1" ], minimize = True ),
25
24
"o2" : Objective (metric = self .metrics ["m2" ], minimize = True ),
26
25
"o3" : Objective (metric = self .metrics ["m3" ], minimize = False ),
27
26
}
@@ -38,6 +37,12 @@ def setUp(self) -> None:
38
37
)
39
38
40
39
def test_Init (self ) -> None :
40
+ with self .assertRaisesRegex (UserInputError , "does not specify" ):
41
+ Objective (metric = self .metrics ["m1" ]),
42
+ with self .assertRaisesRegex (
43
+ UserInputError , "doesn't match the specified optimization direction"
44
+ ):
45
+ Objective (metric = self .metrics ["m2" ], minimize = False )
41
46
with self .assertRaises (ValueError ):
42
47
ScalarizedObjective (
43
48
metrics = [self .metrics ["m1" ], self .metrics ["m2" ]], weights = [1.0 ]
@@ -52,20 +57,6 @@ def test_Init(self) -> None:
52
57
metrics = [self .metrics ["m1" ], self .metrics ["m2" ]],
53
58
minimize = False ,
54
59
)
55
- warnings .resetwarnings ()
56
- warnings .simplefilter ("always" , append = True )
57
- with warnings .catch_warnings (record = True ) as ws :
58
- Objective (metric = self .metrics ["m1" ])
59
- self .assertTrue (any (issubclass (w .category , DeprecationWarning ) for w in ws ))
60
- self .assertTrue (
61
- any ("Defaulting to `minimize=False`" in str (w .message ) for w in ws )
62
- )
63
- with warnings .catch_warnings (record = True ) as ws :
64
- Objective (Metric (name = "m4" , lower_is_better = True ), minimize = False )
65
- self .assertTrue (any ("Attempting to maximize" in str (w .message ) for w in ws ))
66
- with warnings .catch_warnings (record = True ) as ws :
67
- Objective (Metric (name = "m4" , lower_is_better = False ), minimize = True )
68
- self .assertTrue (any ("Attempting to minimize" in str (w .message ) for w in ws ))
69
60
self .assertEqual (
70
61
self .objective .get_unconstrainable_metrics (), [self .metrics ["m1" ]]
71
62
)
@@ -77,15 +68,15 @@ def test_MultiObjective(self) -> None:
77
68
78
69
self .assertEqual (self .multi_objective .metrics , list (self .metrics .values ()))
79
70
minimizes = [obj .minimize for obj in self .multi_objective .objectives ]
80
- self .assertEqual (minimizes , [False , True , False ])
71
+ self .assertEqual (minimizes , [True , True , False ])
81
72
weights = [mw [1 ] for mw in self .multi_objective .objective_weights ]
82
73
self .assertEqual (weights , [1.0 , 1.0 , 1.0 ])
83
74
self .assertEqual (self .multi_objective .clone (), self .multi_objective )
84
75
self .assertEqual (
85
76
str (self .multi_objective ),
86
77
(
87
78
"MultiObjective(objectives="
88
- '[Objective(metric_name="m1", minimize=False ), '
79
+ '[Objective(metric_name="m1", minimize=True ), '
89
80
'Objective(metric_name="m2", minimize=True), '
90
81
'Objective(metric_name="m3", minimize=False)])'
91
82
),
@@ -96,19 +87,26 @@ def test_MultiObjective(self) -> None:
96
87
)
97
88
98
89
def test_MultiObjectiveBackwardsCompatibility (self ) -> None :
99
- multi_objective = MultiObjective (
100
- metrics = [self .metrics ["m1" ], self .metrics ["m2" ], self .metrics ["m3" ]]
101
- )
90
+ metrics = [
91
+ Metric (name = "m1" , lower_is_better = False ),
92
+ self .metrics ["m2" ],
93
+ self .metrics ["m3" ],
94
+ ]
95
+ multi_objective = MultiObjective (metrics = metrics )
102
96
minimizes = [obj .minimize for obj in multi_objective .objectives ]
103
- self .assertEqual (multi_objective .metrics , list ( self . metrics . values ()) )
97
+ self .assertEqual (multi_objective .metrics , metrics )
104
98
self .assertEqual (minimizes , [False , True , False ])
105
99
106
100
multi_objective_min = MultiObjective (
107
- metrics = [self .metrics ["m1" ], self .metrics ["m2" ], self .metrics ["m3" ]],
101
+ metrics = [
102
+ Metric (name = "m1" ),
103
+ Metric (name = "m2" ),
104
+ Metric (name = "m3" , lower_is_better = True ),
105
+ ],
108
106
minimize = True ,
109
107
)
110
108
minimizes = [obj .minimize for obj in multi_objective_min .objectives ]
111
- self .assertEqual (minimizes , [True , False , True ])
109
+ self .assertEqual (minimizes , [True , True , True ])
112
110
113
111
def test_ScalarizedObjective (self ) -> None :
114
112
with self .assertRaises (NotImplementedError ):
0 commit comments