Skip to content

Commit 057370d

Browse files
saitcakmakfacebook-github-bot
authored andcommitted
Remove deprecated torch_dtype input from Adapters
Summary: Removes deprecated `torch_dtype` argument from `Adapter` constructors. Updates the storage code to discard the deprecated kwargs to avoid potential errors when loading old experiments. Differential Revision: D69994060
1 parent 80f9d77 commit 057370d

File tree

10 files changed

+23
-65
lines changed

10 files changed

+23
-65
lines changed

ax/benchmark/benchmark_test_functions/surrogate.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ def evaluate_true(self, params: Mapping[str, TParamValue]) -> Tensor:
6666
return torch.tensor(
6767
means,
6868
device=self.surrogate.device,
69-
dtype=self.surrogate.dtype,
69+
dtype=torch.double,
7070
)
7171

7272
@equality_typechecker

ax/modelbridge/factory.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,6 @@ def get_botorch(
108108
experiment: Experiment,
109109
data: Data,
110110
search_space: SearchSpace | None = None,
111-
dtype: torch.dtype = torch.double,
112111
device: torch.device = DEFAULT_TORCH_DEVICE,
113112
transforms: list[type[Transform]] = Cont_X_trans + Y_trans,
114113
transform_configs: dict[str, TConfig] | None = None,
@@ -127,7 +126,6 @@ def get_botorch(
127126
experiment=experiment,
128127
data=data,
129128
search_space=search_space or experiment.search_space,
130-
torch_dtype=dtype,
131129
torch_device=device,
132130
transforms=transforms,
133131
transform_configs=transform_configs,

ax/modelbridge/map_torch.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,6 @@ def __init__(
6363
model: TorchGenerator,
6464
transforms: list[type[Transform]],
6565
transform_configs: dict[str, TConfig] | None = None,
66-
torch_dtype: torch.dtype | None = None,
6766
torch_device: torch.device | None = None,
6867
status_quo_name: str | None = None,
6968
status_quo_features: ObservationFeatures | None = None,
@@ -90,7 +89,6 @@ def __init__(
9089
the reverse order.
9190
transform_configs: A dictionary from transform name to the
9291
transform config dictionary.
93-
torch_dtype: Torch data type.
9492
torch_device: Torch device.
9593
status_quo_name: Name of the status quo arm. Can only be used if
9694
Data has a single set of ObservationFeatures corresponding to
@@ -134,7 +132,6 @@ def __init__(
134132
model=model,
135133
transforms=transforms,
136134
transform_configs=transform_configs,
137-
torch_dtype=torch_dtype,
138135
torch_device=torch_device,
139136
status_quo_name=status_quo_name,
140137
status_quo_features=status_quo_features,

ax/modelbridge/tests/test_registry.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,6 @@ def test_enum_sobol_legacy_GPEI(self) -> None:
159159
gpei._bridge_kwargs,
160160
{
161161
"transform_configs": None,
162-
"torch_dtype": None,
163162
"torch_device": None,
164163
"status_quo_name": None,
165164
"status_quo_features": None,

ax/modelbridge/tests/test_torch_modelbridge.py

Lines changed: 1 addition & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -88,11 +88,9 @@ def test_TorchAdapter(self, device: torch.device | None = None) -> None:
8888
device=device,
8989
fit_on_init=False,
9090
)
91-
dtype = torch.double
92-
self.assertEqual(model_bridge.dtype, dtype)
9391
self.assertEqual(model_bridge.device, device)
9492
self.assertIsNone(model_bridge._last_observations)
95-
tkwargs: dict[str, Any] = {"dtype": dtype, "device": device}
93+
tkwargs: dict[str, Any] = {"dtype": torch.double, "device": device}
9694
# Test `_fit`.
9795
X = torch.tensor([[1.0, 2.0, 3.0], [2.0, 3.0, 4.0]], **tkwargs)
9896
datasets = {
@@ -285,34 +283,6 @@ def test_TorchAdapter(self, device: torch.device | None = None) -> None:
285283
X = model_bridge._transform_observation_features(observation_features=obsf)
286284
self.assertTrue(torch.equal(X, torch.tensor([[1.0, 2.0]], **tkwargs)))
287285

288-
def _test_TorchAdapter_torch_dtype_deprecated(
289-
self, torch_dtype: torch.dtype
290-
) -> None:
291-
search_space = get_search_space_for_range_values(
292-
min=0.0, max=5.0, parameter_names=["x1", "x2", "x3"]
293-
)
294-
model = mock.MagicMock(TorchGenerator, autospec=True, instance=True)
295-
experiment = Experiment(search_space=search_space, name="test")
296-
with self.assertWarnsRegex(
297-
DeprecationWarning,
298-
"The `torch_dtype` argument to `TorchAdapter` is deprecated",
299-
):
300-
TorchAdapter(
301-
experiment=experiment,
302-
search_space=search_space,
303-
data=experiment.lookup_data(),
304-
model=model,
305-
transforms=[],
306-
fit_on_init=False,
307-
torch_dtype=torch_dtype,
308-
)
309-
310-
def test_TorchAdapter_float(self) -> None:
311-
self._test_TorchAdapter_torch_dtype_deprecated(torch_dtype=torch.float32)
312-
313-
def test_TorchAdapter_float64(self) -> None:
314-
self._test_TorchAdapter_torch_dtype_deprecated(torch_dtype=torch.float64)
315-
316286
def test_TorchAdapter_cuda(self) -> None:
317287
if torch.cuda.is_available():
318288
self.test_TorchAdapter(device=torch.device("cuda"))

ax/modelbridge/torch.py

Lines changed: 6 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
from copy import deepcopy
1414
from logging import Logger
1515
from typing import Any
16-
from warnings import warn
1716

1817
import numpy as np
1918
import numpy.typing as npt
@@ -108,7 +107,6 @@ def __init__(
108107
model: TorchGenerator,
109108
transforms: list[type[Transform]],
110109
transform_configs: dict[str, TConfig] | None = None,
111-
torch_dtype: torch.dtype | None = None,
112110
torch_device: torch.device | None = None,
113111
status_quo_name: str | None = None,
114112
status_quo_features: ObservationFeatures | None = None,
@@ -121,19 +119,6 @@ def __init__(
121119
default_model_gen_options: TConfig | None = None,
122120
fit_only_completed_map_metrics: bool = True,
123121
) -> None:
124-
# This warning is being added while we are on 0.4.3, so it will be
125-
# released in 0.4.4 or 0.5.0. The `torch_dtype` argument can be removed
126-
# in the subsequent minor version. It should also be removed from
127-
# `TorchAdapter` subclasses.
128-
if torch_dtype is not None:
129-
warn(
130-
"The `torch_dtype` argument to `TorchAdapter` is deprecated"
131-
" and will be ignored; data will be in double precision.",
132-
DeprecationWarning,
133-
)
134-
135-
# Note: When `torch_dtype` is removed, this attribute can be removed
136-
self.dtype: torch.dtype = torch.double
137122
self.device = torch_device
138123
# pyre-ignore [4]: Attribute `_default_model_gen_options` of class
139124
# `TorchAdapter` must have a type that does not contain `Any`.
@@ -311,7 +296,7 @@ def _array_list_to_tensors(self, arrays: list[npt.NDArray]) -> list[Tensor]:
311296
return [self._array_to_tensor(x) for x in arrays]
312297

313298
def _array_to_tensor(self, array: npt.NDArray | list[float]) -> Tensor:
314-
return torch.as_tensor(array, dtype=self.dtype, device=self.device)
299+
return torch.as_tensor(array, dtype=torch.double, device=self.device)
315300

316301
def _convert_observations(
317302
self,
@@ -367,10 +352,10 @@ def _convert_observations(
367352
raise ValueError(f"Outcome `{outcome}` was not observed.")
368353
X = torch.stack(Xs[outcome], dim=0)
369354
Y = torch.tensor(
370-
Ys[outcome], dtype=self.dtype, device=self.device
355+
Ys[outcome], dtype=torch.double, device=self.device
371356
).unsqueeze(-1)
372357
Yvar = torch.tensor(
373-
Yvars[outcome], dtype=self.dtype, device=self.device
358+
Yvars[outcome], dtype=torch.double, device=self.device
374359
).unsqueeze(-1)
375360
if Yvar.isnan().all():
376361
Yvar = None
@@ -468,13 +453,13 @@ def _cross_validate(
468453
parameters = self.parameters
469454
X_test = torch.tensor(
470455
[[obsf.parameters[p] for p in parameters] for obsf in cv_test_points],
471-
dtype=self.dtype,
456+
dtype=torch.double,
472457
device=self.device,
473458
)
474459
# Use the model to do the cross validation
475460
f_test, cov_test = none_throws(self.model).cross_validate(
476461
datasets=datasets,
477-
X_test=torch.as_tensor(X_test, dtype=self.dtype, device=self.device),
462+
X_test=torch.as_tensor(X_test, dtype=torch.double, device=self.device),
478463
search_space_digest=search_space_digest,
479464
use_posterior_predictive=use_posterior_predictive,
480465
)
@@ -1040,7 +1025,7 @@ def _extract_observation_data(
10401025
try:
10411026
x = torch.tensor(
10421027
[obsf.parameters[p] for p in parameters],
1043-
dtype=self.dtype,
1028+
dtype=torch.double,
10441029
device=self.device,
10451030
)
10461031
except (KeyError, TypeError):

ax/storage/json_store/decoder.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,9 @@
8585
"ST_MTGP_NEHVI": "ST_MTGP",
8686
}
8787

88+
# Deprecated model kwargs, to be removed from GStep / GNodes.
89+
_DEPRECATED_MODEL_KWARGS: tuple[str, ...] = ("fit_on_update", "torch_dtype")
90+
8891

8992
@dataclass
9093
class RegistryKwargs:
@@ -727,7 +730,9 @@ def generation_step_from_json(
727730
generation_step_json
728731
)
729732
kwargs = generation_step_json.pop("model_kwargs", None)
730-
kwargs.pop("fit_on_update", None) # Remove deprecated fit_on_update.
733+
for k in _DEPRECATED_MODEL_KWARGS:
734+
# Remove deprecated kwargs.
735+
kwargs.pop(k, None)
731736
if kwargs is not None:
732737
kwargs = _extract_surrogate_spec_from_surrogate_specs(kwargs)
733738
gen_kwargs = generation_step_json.pop("model_gen_kwargs", None)
@@ -788,7 +793,9 @@ def model_spec_from_json(
788793
) -> GeneratorSpec:
789794
"""Load GeneratorSpec from JSON."""
790795
kwargs = model_spec_json.pop("model_kwargs", None)
791-
kwargs.pop("fit_on_update", None) # Remove deprecated fit_on_update.
796+
for k in _DEPRECATED_MODEL_KWARGS:
797+
# Remove deprecated model kwargs.
798+
kwargs.pop(k, None)
792799
if kwargs is not None:
793800
kwargs = _extract_surrogate_spec_from_surrogate_specs(kwargs)
794801
gen_kwargs = model_spec_json.pop("model_gen_kwargs", None)

ax/storage/json_store/tests/test_json_store.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -804,7 +804,11 @@ def test_generation_step_backwards_compatibility(self) -> None:
804804
"max_parallelism": None,
805805
"use_update": False,
806806
"enforce_num_trials": True,
807-
"model_kwargs": {"fit_on_update": False, "other_kwarg": 5},
807+
"model_kwargs": {
808+
"fit_on_update": False,
809+
"torch_dtype": torch.double,
810+
"other_kwarg": 5,
811+
},
808812
"model_gen_kwargs": {},
809813
"index": -1,
810814
"should_deduplicate": False,

ax/utils/testing/core_stubs.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1905,7 +1905,7 @@ def get_generator_run() -> GeneratorRun:
19051905
gen_time=5.0,
19061906
model_key="Sobol",
19071907
model_kwargs={"scramble": False, "torch_device": torch.device("cpu")},
1908-
bridge_kwargs={"transforms": Cont_X_trans, "torch_dtype": torch.double},
1908+
bridge_kwargs={"transforms": Cont_X_trans},
19091909
generation_step_index=0,
19101910
candidate_metadata_by_arm_signature={
19111911
a.signature: {"md_key": f"md_val_{a.signature}"} for a in arms

tutorials/multi_task/multi_task.ipynb

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -370,7 +370,6 @@
370370
" search_space: Optional[SearchSpace] = None,\n",
371371
" trial_index: Optional[int] = None,\n",
372372
" device: torch.device = torch.device(\"cpu\"),\n",
373-
" dtype: torch.dtype = torch.double,\n",
374373
") -> TorchAdapter:\n",
375374
" \"\"\"Instantiates a Multi-task Gaussian Process (MTGP) model that generates\n",
376375
" points with EI.\n",
@@ -419,7 +418,6 @@
419418
" data=data,\n",
420419
" transforms=transforms,\n",
421420
" transform_configs=transform_configs,\n",
422-
" torch_dtype=dtype,\n",
423421
" torch_device=device,\n",
424422
" status_quo_features=status_quo_features,\n",
425423
" ),\n",

0 commit comments

Comments
 (0)