Skip to content

Commit 8612929

Browse files
saitcakmakfacebook-github-bot
authored andcommitted
Remove unused refit_on_update kwarg
Summary: This has no usage since the `update` method has been deprecated. Removing it to clean up the API. Updated the decoders to make sure we can load previous GS. Reviewed By: Balandat Differential Revision: D56796465 fbshipit-source-id: e759b382a50f8644f76e1ea44f671a672646d977
1 parent 5ae5ceb commit 8612929

File tree

14 files changed

+24
-27
lines changed

14 files changed

+24
-27
lines changed

ax/modelbridge/factory.py

-2
Original file line numberDiff line numberDiff line change
@@ -245,7 +245,6 @@ def get_botorch(
245245
acqf_constructor: TAcqfConstructor = get_qLogNEI,
246246
acqf_optimizer: TOptimizer = scipy_optimizer, # pyre-ignore[9]
247247
refit_on_cv: bool = False,
248-
refit_on_update: bool = True,
249248
optimization_config: Optional[OptimizationConfig] = None,
250249
) -> TorchModelBridge:
251250
"""Instantiates a BotorchModel."""
@@ -266,7 +265,6 @@ def get_botorch(
266265
acqf_constructor=acqf_constructor,
267266
acqf_optimizer=acqf_optimizer,
268267
refit_on_cv=refit_on_cv,
269-
refit_on_update=refit_on_update,
270268
optimization_config=optimization_config,
271269
),
272270
)

ax/modelbridge/tests/test_registry.py

-1
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,6 @@ def test_enum_sobol_GPEI(self) -> None:
153153
"value": f"{botorch_defaults}.recommend_best_observed_point",
154154
},
155155
"refit_on_cv": False,
156-
"refit_on_update": True,
157156
"warm_start_refitting": True,
158157
"use_input_warping": False,
159158
"use_loocv_pseudo_likelihood": False,

ax/models/tests/test_alebo.py

-1
Original file line numberDiff line numberDiff line change
@@ -286,7 +286,6 @@ def test_ALEBO(self) -> None:
286286
self.assertTrue(torch.equal(B, m.B))
287287
self.assertEqual(m.laplace_nsamp, 5)
288288
self.assertEqual(m.fit_restarts, 1)
289-
self.assertEqual(m.refit_on_update, True)
290289
self.assertEqual(m.refit_on_cv, False)
291290
self.assertEqual(m.warm_start_refitting, False)
292291

ax/models/torch/alebo.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -792,7 +792,7 @@ def alebo_acqf_optimizer(
792792
if base_X_pending is not None
793793
else candidates
794794
)
795-
logger.info(f"Generated sequential candidate {i+1} of {n}")
795+
logger.info(f"Generated sequential candidate {i + 1} of {n}")
796796
if acq_has_X_pend:
797797
acq_function.set_X_pending(base_X_pending)
798798
return candidates, torch.stack(acq_value_list)
@@ -828,7 +828,6 @@ def __init__(
828828
self.laplace_nsamp = laplace_nsamp
829829
self.fit_restarts = fit_restarts
830830
super().__init__(
831-
refit_on_update=True, # Important to not get stuck in local opt.
832831
refit_on_cv=False,
833832
warm_start_refitting=False,
834833
acqf_constructor=ei_or_nei, # pyre-ignore

ax/models/torch/botorch.py

-4
Original file line numberDiff line numberDiff line change
@@ -125,8 +125,6 @@ class BotorchModel(TorchModel):
125125
signature as described below.
126126
refit_on_cv: If True, refit the model for each fold when performing
127127
cross-validation.
128-
refit_on_update: If True, refit the model after updating the training
129-
data using the `update` method.
130128
warm_start_refitting: If True, start model refitting from previous
131129
model parameters in order to speed up the fitting process.
132130
prior: An optional dictionary that contains the specification of GP model prior.
@@ -251,7 +249,6 @@ def __init__(
251249
acqf_optimizer: TOptimizer = scipy_optimizer,
252250
best_point_recommender: TBestPointRecommender = recommend_best_observed_point,
253251
refit_on_cv: bool = False,
254-
refit_on_update: bool = True,
255252
warm_start_refitting: bool = True,
256253
use_input_warping: bool = False,
257254
use_loocv_pseudo_likelihood: bool = False,
@@ -276,7 +273,6 @@ def __init__(
276273
# pyre-fixme[4]: Attribute must be annotated.
277274
self._kwargs = kwargs
278275
self.refit_on_cv = refit_on_cv
279-
self.refit_on_update = refit_on_update
280276
self.warm_start_refitting = warm_start_refitting
281277
self.use_input_warping = use_input_warping
282278
self.use_loocv_pseudo_likelihood = use_loocv_pseudo_likelihood

ax/models/torch/botorch_modular/model.py

-3
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,6 @@ class BoTorchModel(TorchModel, Base):
140140
based on the data provided.
141141
surrogate: In liu of SurrogateSpecs, an instance of `Surrogate` may be
142142
provided to be used as the sole Surrogate for all outcomes
143-
refit_on_update: Unused.
144143
refit_on_cv: Whether to reoptimize model parameters during call to
145144
`BoTorchmodel.cross_validate`.
146145
warm_start_refit: Whether to load parameters from either the provided
@@ -169,7 +168,6 @@ def __init__(
169168
acquisition_options: Optional[Dict[str, Any]] = None,
170169
botorch_acqf_class: Optional[Type[AcquisitionFunction]] = None,
171170
# TODO: [T168715924] Revisit these "refit" arguments.
172-
refit_on_update: bool = True,
173171
refit_on_cv: bool = False,
174172
warm_start_refit: bool = True,
175173
) -> None:
@@ -216,7 +214,6 @@ def __init__(
216214
self.acquisition_options = acquisition_options or {}
217215
self._botorch_acqf_class = botorch_acqf_class
218216

219-
self.refit_on_update = refit_on_update
220217
self.refit_on_cv = refit_on_cv
221218
self.warm_start_refit = warm_start_refit
222219

ax/models/torch/botorch_moo.py

-2
Original file line numberDiff line numberDiff line change
@@ -206,7 +206,6 @@ def __init__(
206206
best_point_recommender: TBestPointRecommender = recommend_best_observed_point,
207207
frontier_evaluator: TFrontierEvaluator = pareto_frontier_evaluator,
208208
refit_on_cv: bool = False,
209-
refit_on_update: bool = True,
210209
warm_start_refitting: bool = False,
211210
use_input_warping: bool = False,
212211
use_loocv_pseudo_likelihood: bool = False,
@@ -222,7 +221,6 @@ def __init__(
222221
# pyre-fixme[4]: Attribute must be annotated.
223222
self._kwargs = kwargs
224223
self.refit_on_cv = refit_on_cv
225-
self.refit_on_update = refit_on_update
226224
self.warm_start_refitting = warm_start_refitting
227225
self.use_input_warping = use_input_warping
228226
self.use_loocv_pseudo_likelihood = use_loocv_pseudo_likelihood

ax/models/torch/fully_bayesian.py

-6
Original file line numberDiff line numberDiff line change
@@ -523,7 +523,6 @@ def __init__(
523523
acqf_optimizer: TOptimizer = scipy_optimizer,
524524
best_point_recommender: TBestPointRecommender = recommend_best_observed_point,
525525
refit_on_cv: bool = False,
526-
refit_on_update: bool = True,
527526
warm_start_refitting: bool = True,
528527
use_input_warping: bool = False,
529528
# use_saas is deprecated. TODO: remove
@@ -553,8 +552,6 @@ def __init__(
553552
signature as described below.
554553
refit_on_cv: If True, refit the model for each fold when performing
555554
cross-validation.
556-
refit_on_update: If True, refit the model after updating the training
557-
data using the `update` method.
558555
warm_start_refitting: If True, start model refitting from previous
559556
model parameters in order to speed up the fitting process.
560557
use_input_warping: A boolean indicating whether to use input warping
@@ -581,7 +578,6 @@ def __init__(
581578
acqf_optimizer=acqf_optimizer,
582579
best_point_recommender=best_point_recommender,
583580
refit_on_cv=refit_on_cv,
584-
refit_on_update=refit_on_update,
585581
warm_start_refitting=warm_start_refitting,
586582
use_input_warping=use_input_warping,
587583
num_samples=num_samples,
@@ -619,7 +615,6 @@ def __init__(
619615
best_point_recommender: TBestPointRecommender = recommend_best_observed_point,
620616
frontier_evaluator: TFrontierEvaluator = pareto_frontier_evaluator,
621617
refit_on_cv: bool = False,
622-
refit_on_update: bool = True,
623618
warm_start_refitting: bool = False,
624619
use_input_warping: bool = False,
625620
num_samples: int = 256,
@@ -646,7 +641,6 @@ def __init__(
646641
best_point_recommender=best_point_recommender,
647642
frontier_evaluator=frontier_evaluator,
648643
refit_on_cv=refit_on_cv,
649-
refit_on_update=refit_on_update,
650644
warm_start_refitting=warm_start_refitting,
651645
use_input_warping=use_input_warping,
652646
num_samples=num_samples,

ax/models/torch/tests/test_model.py

-3
Original file line numberDiff line numberDiff line change
@@ -188,7 +188,6 @@ def test_init(self) -> None:
188188
self.assertEqual(model.botorch_acqf_class, qExpectedImprovement)
189189

190190
# Check defaults for refitting settings.
191-
self.assertTrue(model.refit_on_update)
192191
self.assertFalse(model.refit_on_cv)
193192
self.assertTrue(model.warm_start_refit)
194193

@@ -197,11 +196,9 @@ def test_init(self) -> None:
197196
surrogate=self.surrogate,
198197
acquisition_class=self.acquisition_class,
199198
acquisition_options=self.acquisition_options,
200-
refit_on_update=False,
201199
refit_on_cv=True,
202200
warm_start_refit=False,
203201
)
204-
self.assertFalse(mdl2.refit_on_update)
205202
self.assertTrue(mdl2.refit_on_cv)
206203
self.assertFalse(mdl2.warm_start_refit)
207204

ax/storage/json_store/decoder.py

+2
Original file line numberDiff line numberDiff line change
@@ -683,6 +683,7 @@ def generation_step_from_json(
683683
generation_step_json
684684
)
685685
kwargs = generation_step_json.pop("model_kwargs", None)
686+
kwargs.pop("fit_on_update", None) # Remove deprecated fit_on_update.
686687
gen_kwargs = generation_step_json.pop("model_gen_kwargs", None)
687688
completion_criteria = (
688689
object_from_json(
@@ -741,6 +742,7 @@ def model_spec_from_json(
741742
) -> ModelSpec:
742743
"""Load ModelSpec from JSON."""
743744
kwargs = model_spec_json.pop("model_kwargs", None)
745+
kwargs.pop("fit_on_update", None) # Remove deprecated fit_on_update.
744746
gen_kwargs = model_spec_json.pop("model_gen_kwargs", None)
745747
return ModelSpec(
746748
model_enum=object_from_json(

ax/storage/json_store/encoders.py

-1
Original file line numberDiff line numberDiff line change
@@ -579,7 +579,6 @@ def botorch_model_to_dict(model: BoTorchModel) -> Dict[str, Any]:
579579
model.surrogate_specs if len(model.surrogate_specs) > 0 else None
580580
),
581581
"botorch_acqf_class": model._botorch_acqf_class,
582-
"refit_on_update": model.refit_on_update,
583582
"refit_on_cv": model.refit_on_cv,
584583
"warm_start_refit": model.warm_start_refit,
585584
}

ax/storage/json_store/tests/test_json_store.py

+21
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from ax.core.runner import Runner
1919
from ax.exceptions.core import AxStorageWarning
2020
from ax.exceptions.storage import JSONDecodeError, JSONEncodeError
21+
from ax.modelbridge.generation_node import GenerationStep
2122
from ax.modelbridge.generation_strategy import GenerationStrategy
2223
from ax.modelbridge.registry import Models
2324
from ax.storage.json_store.decoder import (
@@ -668,3 +669,23 @@ def test_objective_backwards_compatibility(self) -> None:
668669
self.assertNotEqual(objective, objective_loaded)
669670
self.assertTrue(objective_loaded.minimize)
670671
self.assertTrue(objective_loaded.metric.lower_is_better)
672+
673+
def test_generation_step_backwards_compatibility(self) -> None:
674+
# Test that we can load a generation step with fit_on_update.
675+
json = {
676+
"__type": "GenerationStep",
677+
"model": {"__type": "Models", "name": "BOTORCH_MODULAR"},
678+
"num_trials": 5,
679+
"min_trials_observed": 0,
680+
"completion_criteria": [],
681+
"max_parallelism": None,
682+
"use_update": False,
683+
"enforce_num_trials": True,
684+
"model_kwargs": {"fit_on_update": False, "other_kwarg": 5},
685+
"model_gen_kwargs": {},
686+
"index": -1,
687+
"should_deduplicate": False,
688+
}
689+
generation_step = object_from_json(json)
690+
self.assertIsInstance(generation_step, GenerationStep)
691+
self.assertEqual(generation_step.model_kwargs, {"other_kwarg": 5})

ax/utils/common/constants.py

-1
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,6 @@ class Keys(str, Enum):
7373
QMC = "qmc"
7474
RAW_INNER_SAMPLES = "raw_inner_samples"
7575
RAW_SAMPLES = "raw_samples"
76-
REFIT_ON_UPDATE = "refit_on_update"
7776
SAMPLER = "sampler"
7877
SEED_INNER = "seed_inner"
7978
SEQUENTIAL = "sequential"

tutorials/modular_botax.ipynb

-1
Original file line numberDiff line numberDiff line change
@@ -236,7 +236,6 @@
236236
" acquisition_class=None,\n",
237237
" # Less common model settings shown with default values, refer\n",
238238
" # to `BoTorchModel` documentation for detail\n",
239-
" refit_on_update=True,\n",
240239
" refit_on_cv=False,\n",
241240
" warm_start_refit=True,\n",
242241
")"

0 commit comments

Comments
 (0)