@@ -1251,90 +1251,35 @@ def test_pairwise_preference_generator(self) -> None:
12511251 X = X .expand (2 , * X .shape ), Y = comp_pair_Y .expand (2 , * comp_pair_Y .shape )
12521252 )
12531253
1254- def test_get_transformed_model_gen_args_with_target_point (self ) -> None :
1255- # Test that _get_transformed_model_gen_args correctly processes target_point
1256-
1257- # Setup: create adapter with target arm in optimization config
1258- experiment = get_branin_experiment (with_completed_trial = True )
1259- pruning_target_parameterization = Arm (parameters = {"x1" : - 5.0 , "x2" : 15.0 })
1260- optimization_config = none_throws (
1261- experiment .optimization_config
1262- ).clone_with_args (
1263- pruning_target_parameterization = pruning_target_parameterization
1264- )
1265-
1266- adapter = TorchAdapter (
1267- generator = TorchGenerator (),
1268- experiment = experiment ,
1269- transforms = Cont_X_trans ,
1270- )
1271-
1272- # Execute: call _get_transformed_gen_args then _get_transformed_model_gen_args
1273- base_gen_args = adapter ._get_transformed_gen_args (
1274- search_space = experiment .search_space ,
1275- optimization_config = optimization_config ,
1276- pending_observations = {},
1277- )
1278-
1279- search_space_digest , torch_opt_config = adapter ._get_transformed_model_gen_args (
1280- search_space = base_gen_args .search_space ,
1281- pending_observations = base_gen_args .pending_observations ,
1282- fixed_features = base_gen_args .fixed_features ,
1283- optimization_config = base_gen_args .optimization_config ,
1284- )
1285-
1286- # Assert: confirm pruning_target_point is correctly extracted and transformed
1287- self .assertIsNotNone (torch_opt_config .pruning_target_point )
1288- expected_target = torch .tensor ([0.0 , 1.0 ], dtype = torch .double )
1289- torch .testing .assert_close (
1290- torch_opt_config .pruning_target_point , expected_target
1254+ def _test_get_transformed_model_gen_args_target_point (
1255+ self ,
1256+ with_status_quo : bool ,
1257+ pruning_target_params : dict [str , float ] | None ,
1258+ expected_target : torch .Tensor | None ,
1259+ ) -> None :
1260+ experiment = get_branin_experiment (
1261+ with_completed_trial = True ,
1262+ with_status_quo = with_status_quo ,
12911263 )
12921264
1293- def test_get_transformed_model_gen_args_no_target_point (self ) -> None :
1294- # Test that _get_transformed_model_gen_args handles
1295- # pruning_target_parameterization=None correctly
1265+ opt_config = none_throws (experiment .optimization_config )
1266+ if pruning_target_params is not None :
1267+ pruning_target = Arm (parameters = pruning_target_params )
1268+ opt_config = opt_config .clone_with_args (
1269+ pruning_target_parameterization = pruning_target
1270+ )
1271+ elif with_status_quo :
1272+ opt_config = opt_config .clone ()
12961273
1297- # Setup: create adapter without target arm (default case)
1298- experiment = get_branin_experiment (with_completed_trial = True )
12991274 adapter = TorchAdapter (
13001275 generator = TorchGenerator (),
13011276 experiment = experiment ,
13021277 transforms = Cont_X_trans ,
13031278 )
13041279
1305- # Execute: call _get_transformed_gen_args then _get_transformed_model_gen_args
13061280 base_gen_args = adapter ._get_transformed_gen_args (
13071281 search_space = experiment .search_space ,
1308- optimization_config = none_throws (experiment .optimization_config ),
1309- pending_observations = {},
1310- )
1311-
1312- search_space_digest , torch_opt_config = adapter ._get_transformed_model_gen_args (
1313- search_space = base_gen_args .search_space ,
1314- pending_observations = base_gen_args .pending_observations ,
1315- fixed_features = base_gen_args .fixed_features ,
1316- optimization_config = base_gen_args .optimization_config ,
1317- )
1318-
1319- # Assert: confirm target_point is None when no pruning_target_parameterization
1320- # is provided
1321- self .assertIsNone (torch_opt_config .pruning_target_point )
1322-
1323- def test_get_transformed_model_gen_args_with_sq_as_target (self ) -> None :
1324- # Test that _get_transformed_model_gen_args correctly processes the status quo
1325- # as the target point
1326- experiment = get_branin_experiment (
1327- with_completed_trial = True , with_status_quo = True
1328- )
1329-
1330- adapter = TorchAdapter (
1331- generator = TorchGenerator (), experiment = experiment , transforms = Cont_X_trans
1332- )
1333- oc = none_throws (experiment .optimization_config ).clone ()
1334- # Execute: call _get_transformed_gen_args then _get_transformed_model_gen_args
1335- base_gen_args = adapter ._get_transformed_gen_args (
1336- search_space = experiment .search_space ,
1337- optimization_config = oc ,
1282+ optimization_config = opt_config ,
13381283 pending_observations = {},
13391284 )
13401285
@@ -1345,12 +1290,38 @@ def test_get_transformed_model_gen_args_with_sq_as_target(self) -> None:
13451290 optimization_config = base_gen_args .optimization_config ,
13461291 )
13471292
1348- # Assert: confirm pruning_target_point is correctly extracted and transformed
1349- self .assertIsNotNone (torch_opt_config .pruning_target_point )
1350- expected_target = torch .tensor ([1 / 3.0 , 0.0 ], dtype = torch .double )
1351- torch .testing .assert_close (
1352- torch_opt_config .pruning_target_point , expected_target
1353- )
1293+ if expected_target is None :
1294+ self .assertIsNone (torch_opt_config .pruning_target_point )
1295+ else :
1296+ self .assertIsNotNone (torch_opt_config .pruning_target_point )
1297+ torch .testing .assert_close (
1298+ torch_opt_config .pruning_target_point ,
1299+ expected_target ,
1300+ )
1301+
1302+ def test_get_transformed_model_gen_args_target_point (self ) -> None :
1303+ # Test _get_transformed_model_gen_args with various target point scenarios
1304+ for label , with_status_quo , pruning_target_params , expected_target in [
1305+ (
1306+ "with_target_point" ,
1307+ False ,
1308+ {"x1" : - 5.0 , "x2" : 15.0 },
1309+ torch .tensor ([0.0 , 1.0 ], dtype = torch .double ),
1310+ ),
1311+ ("no_target_point" , False , None , None ),
1312+ (
1313+ "sq_as_target" ,
1314+ True ,
1315+ None ,
1316+ torch .tensor ([1 / 3.0 , 0.0 ], dtype = torch .double ),
1317+ ),
1318+ ]:
1319+ with self .subTest (scenario = label ):
1320+ self ._test_get_transformed_model_gen_args_target_point (
1321+ with_status_quo = with_status_quo ,
1322+ pruning_target_params = pruning_target_params ,
1323+ expected_target = expected_target ,
1324+ )
13541325
13551326 @mock_botorch_optimize
13561327 def test_moo_with_derived_parameter (self ) -> None :
0 commit comments