66
77# pyre-strict
88
9+ from typing import Any
910from unittest .mock import Mock , patch
1011
1112import numpy as np
12- import numpy .typing as npt
1313from ax .adapter .registry import Generators
14+ from ax .core .types import TParameterization
1415from ax .exceptions .core import UserInputError
1516from ax .generation_strategy .generation_strategy import (
1617 GenerationStep ,
1718 GenerationStrategy ,
1819)
20+ from ax .generators .random .sobol import SobolGenerator
1921from ax .metrics .branin import branin
2022from ax .service .managed_loop import OptimizationLoop , optimize
2123from ax .utils .common .testutils import TestCase
2224from ax .utils .testing .mock import mock_botorch_optimize
25+ from pyre_extensions import assert_is_instance , none_throws
2326
2427
2528def _branin_evaluation_function (
26- # pyre-fixme[2]: Parameter must be annotated.
27- parameterization ,
28- weight = None , # pyre-fixme[2]: Parameter must be annotated.
29- ) -> dict [str , tuple [float | npt .NDArray , float ]]:
29+ parameterization : TParameterization ,
30+ weight : float | None = None ,
31+ ) -> dict [str , tuple [float , float ]]:
3032 if any (param_name not in parameterization .keys () for param_name in ["x1" , "x2" ]):
3133 raise ValueError ("Parametrization does not contain x1 or x2" )
32- x1 , x2 = parameterization ["x1" ], parameterization ["x2" ]
34+ x1 , x2 = float ( parameterization ["x1" ]), float ( parameterization ["x2" ])
3335 return {
34- "branin" : (branin (x1 , x2 ), 0.0 ),
35- "constrained_metric" : (- branin (x1 , x2 ), 0.0 ),
36+ "branin" : (float ( branin (x1 , x2 ) ), 0.0 ),
37+ "constrained_metric" : (float ( - branin (x1 , x2 ) ), 0.0 ),
3638 }
3739
3840
3941def _branin_evaluation_function_v2 (
40- # pyre-fixme[2]: Parameter must be annotated.
41- parameterization ,
42- weight = None , # pyre-fixme[2]: Parameter must be annotated.
43- ) -> tuple [float | npt .NDArray , float ]:
42+ parameterization : TParameterization ,
43+ weight : float | None = None ,
44+ ) -> tuple [float , float ]:
4445 if any (param_name not in parameterization .keys () for param_name in ["x1" , "x2" ]):
4546 raise ValueError ("Parametrization does not contain x1 or x2" )
46- x1 , x2 = parameterization ["x1" ], parameterization ["x2" ]
47- return (branin (x1 , x2 ), 0.0 )
47+ x1 , x2 = float ( parameterization ["x1" ]), float ( parameterization ["x2" ])
48+ return (float ( branin (x1 , x2 ) ), 0.0 )
4849
4950
5051def _branin_evaluation_function_with_unknown_sem (
51- # pyre-fixme[2]: Parameter must be annotated.
52- parameterization ,
53- weight = None , # pyre-fixme[2]: Parameter must be annotated.
54- ) -> tuple [float | npt .NDArray , None ]:
52+ parameterization : TParameterization ,
53+ weight : float | None = None ,
54+ ) -> tuple [float , None ]:
5555 if any (param_name not in parameterization .keys () for param_name in ["x1" , "x2" ]):
5656 raise ValueError ("Parametrization does not contain x1 or x2" )
57- x1 , x2 = parameterization ["x1" ], parameterization ["x2" ]
58- return (branin (x1 , x2 ), None )
57+ x1 , x2 = float ( parameterization ["x1" ]), float ( parameterization ["x2" ])
58+ return (float ( branin (x1 , x2 ) ), None )
5959
6060
6161class TestManagedLoop (TestCase ):
6262 """Check functionality of optimization loop."""
6363
6464 def test_with_evaluation_function_propagates_parameter_constraints (self ) -> None :
65- kwargs = {
65+ kwargs : dict [ str , Any ] = {
6666 "parameters" : [
6767 {
6868 "name" : "x1" ,
@@ -151,9 +151,7 @@ def test_branin_with_active_parameter_constraints(self) -> None:
151151 bp , _ = loop .full_run ().get_best_point ()
152152 self .assertIn ("x1" , bp )
153153 self .assertIn ("x2" , bp )
154- # pyre-fixme[58]: `+` is not supported for operand types `Union[None, bool,
155- # float, int, str]` and `Union[None, bool, float, int, str]`.
156- self .assertLessEqual (bp ["x1" ] + bp ["x2" ], 1.0 + 1e-8 )
154+ self .assertLessEqual (float (bp ["x1" ]) + float (bp ["x2" ]), 1.0 + 1e-8 )
157155 with self .assertRaisesRegex (ValueError , "Optimization is complete" ):
158156 loop .run_trial ()
159157
@@ -241,11 +239,8 @@ def test_branin_batch(self) -> None:
241239 self .assertIn ("x2" , bp )
242240 assert vals is not None
243241 self .assertIn ("branin" , vals [0 ])
244- # pyre-fixme[6]: For 2nd param expected `Union[Container[typing.Any],
245- # Iterable[typing.Any]]` but got `Optional[Dict[str, Dict[str, float]]]`.
246- self .assertIn ("branin" , vals [1 ])
247- # pyre-fixme[16]: Optional type has no attribute `__getitem__`.
248- self .assertIn ("branin" , vals [1 ]["branin" ])
242+ self .assertIn ("branin" , none_throws (vals [1 ]))
243+ self .assertIn ("branin" , none_throws (vals [1 ])["branin" ])
249244 # Check that all total_trials * arms_per_trial * 2 metrics evaluations
250245 # are present in the dataframe.
251246 self .assertEqual (len (loop .experiment .fetch_data ().df .index ), 12 )
@@ -270,11 +265,8 @@ def test_optimize(self) -> None:
270265 self .assertIn ("x2" , best )
271266 assert vals is not None
272267 self .assertIn ("objective" , vals [0 ])
273- # pyre-fixme[6]: For 2nd param expected `Union[Container[typing.Any],
274- # Iterable[typing.Any]]` but got `Optional[Dict[str, Dict[str, float]]]`.
275- self .assertIn ("objective" , vals [1 ])
276- # pyre-fixme[16]: Optional type has no attribute `__getitem__`.
277- self .assertIn ("objective" , vals [1 ]["objective" ])
268+ self .assertIn ("objective" , none_throws (vals [1 ]))
269+ self .assertIn ("objective" , none_throws (vals [1 ])["objective" ])
278270
279271 @patch (
280272 "ax.service.managed_loop."
@@ -301,11 +293,8 @@ def test_optimize_with_predictions(self, _) -> None:
301293 self .assertIn ("x2" , best )
302294 assert vals is not None
303295 self .assertIn ("a" , vals [0 ])
304- # pyre-fixme[6]: For 2nd param expected `Union[Container[typing.Any],
305- # Iterable[typing.Any]]` but got `Optional[Dict[str, Dict[str, float]]]`.
306- self .assertIn ("a" , vals [1 ])
307- # pyre-fixme[16]: Optional type has no attribute `__getitem__`.
308- self .assertIn ("a" , vals [1 ]["a" ])
296+ self .assertIn ("a" , none_throws (vals [1 ]))
297+ self .assertIn ("a" , none_throws (vals [1 ])["a" ])
309298
310299 @mock_botorch_optimize
311300 def test_optimize_unknown_sem (self ) -> None :
@@ -327,11 +316,8 @@ def test_optimize_unknown_sem(self) -> None:
327316 self .assertIn ("x2" , best )
328317 self .assertIsNotNone (vals )
329318 self .assertIn ("objective" , vals [0 ])
330- # pyre-fixme[6]: For 2nd param expected `Union[Container[typing.Any],
331- # Iterable[typing.Any]]` but got `Optional[Dict[str, Dict[str, float]]]`.
332- self .assertIn ("objective" , vals [1 ])
333- # pyre-fixme[16]: Optional type has no attribute `__getitem__`.
334- self .assertIn ("objective" , vals [1 ]["objective" ])
319+ self .assertIn ("objective" , none_throws (vals [1 ]))
320+ self .assertIn ("objective" , none_throws (vals [1 ])["objective" ])
335321
336322 def test_optimize_propagates_random_seed (self ) -> None :
337323 """Tests optimization as a single call."""
@@ -347,8 +333,8 @@ def test_optimize_propagates_random_seed(self) -> None:
347333 total_trials = 5 ,
348334 random_seed = 12345 ,
349335 )
350- # pyre-fixme[16]: Optional type has no attribute ` model`.
351- self .assertEqual (12345 , model . generator .seed )
336+ generator = assert_is_instance ( none_throws ( model ). generator , SobolGenerator )
337+ self .assertEqual (12345 , generator .seed )
352338
353339 def test_optimize_search_space_exhausted (self ) -> None :
354340 """Tests optimization as a single call."""
@@ -370,11 +356,8 @@ def test_optimize_search_space_exhausted(self) -> None:
370356 self .assertIn ("x2" , best )
371357 self .assertIsNotNone (vals )
372358 self .assertIn ("objective" , vals [0 ])
373- # pyre-fixme[6]: For 2nd param expected `Union[Container[typing.Any],
374- # Iterable[typing.Any]]` but got `Optional[Dict[str, Dict[str, float]]]`.
375- self .assertIn ("objective" , vals [1 ])
376- # pyre-fixme[16]: Optional type has no attribute `__getitem__`.
377- self .assertIn ("objective" , vals [1 ]["objective" ])
359+ self .assertIn ("objective" , none_throws (vals [1 ]))
360+ self .assertIn ("objective" , none_throws (vals [1 ])["objective" ])
378361
379362 def test_custom_gs (self ) -> None :
380363 """Managed loop with custom generation strategy"""
@@ -432,18 +415,14 @@ def test_optimize_graceful_exit_on_exception(self) -> None:
432415 self .assertIn ("x2" , best )
433416 self .assertIsNotNone (vals )
434417 self .assertIn ("objective" , vals [0 ])
435- # pyre-fixme[6]: For 2nd param expected `Union[Container[typing.Any],
436- # Iterable[typing.Any]]` but got `Optional[Dict[str, Dict[str, float]]]`.
437- self .assertIn ("objective" , vals [1 ])
438- # pyre-fixme[16]: Optional type has no attribute `__getitem__`.
439- self .assertIn ("objective" , vals [1 ]["objective" ])
418+ self .assertIn ("objective" , none_throws (vals [1 ]))
419+ self .assertIn ("objective" , none_throws (vals [1 ])["objective" ])
440420
441421 @patch (
442422 "ax.core.experiment.Experiment.new_trial" ,
443423 side_effect = RuntimeError ("cholesky_cpu error - bad matrix" ),
444424 )
445- # pyre-fixme[3]: Return type must be annotated.
446- def test_annotate_exception (self , _ ):
425+ def test_annotate_exception (self , _ : Mock ) -> None :
447426 strategy0 = GenerationStrategy (
448427 name = "Sobol" ,
449428 steps = [GenerationStep (generator = Generators .SOBOL , num_trials = - 1 )],
0 commit comments