Skip to content

Commit

Permalink
Remove deprecated torch_dtype input from Adapters
Browse files Browse the repository at this point in the history
Summary: Removes deprecated `torch_dtype` argument from `Adapter` constructors. Updates the storage code to discard the deprecated kwargs to avoid potential errors when loading old experiments.

Differential Revision: D69994060
  • Loading branch information
saitcakmak authored and facebook-github-bot committed Feb 21, 2025
1 parent 80f9d77 commit 057370d
Show file tree
Hide file tree
Showing 10 changed files with 23 additions and 65 deletions.
2 changes: 1 addition & 1 deletion ax/benchmark/benchmark_test_functions/surrogate.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def evaluate_true(self, params: Mapping[str, TParamValue]) -> Tensor:
return torch.tensor(
means,
device=self.surrogate.device,
dtype=self.surrogate.dtype,
dtype=torch.double,
)

@equality_typechecker
Expand Down
2 changes: 0 additions & 2 deletions ax/modelbridge/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,6 @@ def get_botorch(
experiment: Experiment,
data: Data,
search_space: SearchSpace | None = None,
dtype: torch.dtype = torch.double,
device: torch.device = DEFAULT_TORCH_DEVICE,
transforms: list[type[Transform]] = Cont_X_trans + Y_trans,
transform_configs: dict[str, TConfig] | None = None,
Expand All @@ -127,7 +126,6 @@ def get_botorch(
experiment=experiment,
data=data,
search_space=search_space or experiment.search_space,
torch_dtype=dtype,
torch_device=device,
transforms=transforms,
transform_configs=transform_configs,
Expand Down
3 changes: 0 additions & 3 deletions ax/modelbridge/map_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,6 @@ def __init__(
model: TorchGenerator,
transforms: list[type[Transform]],
transform_configs: dict[str, TConfig] | None = None,
torch_dtype: torch.dtype | None = None,
torch_device: torch.device | None = None,
status_quo_name: str | None = None,
status_quo_features: ObservationFeatures | None = None,
Expand All @@ -90,7 +89,6 @@ def __init__(
the reverse order.
transform_configs: A dictionary from transform name to the
transform config dictionary.
torch_dtype: Torch data type.
torch_device: Torch device.
status_quo_name: Name of the status quo arm. Can only be used if
Data has a single set of ObservationFeatures corresponding to
Expand Down Expand Up @@ -134,7 +132,6 @@ def __init__(
model=model,
transforms=transforms,
transform_configs=transform_configs,
torch_dtype=torch_dtype,
torch_device=torch_device,
status_quo_name=status_quo_name,
status_quo_features=status_quo_features,
Expand Down
1 change: 0 additions & 1 deletion ax/modelbridge/tests/test_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,6 @@ def test_enum_sobol_legacy_GPEI(self) -> None:
gpei._bridge_kwargs,
{
"transform_configs": None,
"torch_dtype": None,
"torch_device": None,
"status_quo_name": None,
"status_quo_features": None,
Expand Down
32 changes: 1 addition & 31 deletions ax/modelbridge/tests/test_torch_modelbridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,11 +88,9 @@ def test_TorchAdapter(self, device: torch.device | None = None) -> None:
device=device,
fit_on_init=False,
)
dtype = torch.double
self.assertEqual(model_bridge.dtype, dtype)
self.assertEqual(model_bridge.device, device)
self.assertIsNone(model_bridge._last_observations)
tkwargs: dict[str, Any] = {"dtype": dtype, "device": device}
tkwargs: dict[str, Any] = {"dtype": torch.double, "device": device}
# Test `_fit`.
X = torch.tensor([[1.0, 2.0, 3.0], [2.0, 3.0, 4.0]], **tkwargs)
datasets = {
Expand Down Expand Up @@ -285,34 +283,6 @@ def test_TorchAdapter(self, device: torch.device | None = None) -> None:
X = model_bridge._transform_observation_features(observation_features=obsf)
self.assertTrue(torch.equal(X, torch.tensor([[1.0, 2.0]], **tkwargs)))

def _test_TorchAdapter_torch_dtype_deprecated(
self, torch_dtype: torch.dtype
) -> None:
search_space = get_search_space_for_range_values(
min=0.0, max=5.0, parameter_names=["x1", "x2", "x3"]
)
model = mock.MagicMock(TorchGenerator, autospec=True, instance=True)
experiment = Experiment(search_space=search_space, name="test")
with self.assertWarnsRegex(
DeprecationWarning,
"The `torch_dtype` argument to `TorchAdapter` is deprecated",
):
TorchAdapter(
experiment=experiment,
search_space=search_space,
data=experiment.lookup_data(),
model=model,
transforms=[],
fit_on_init=False,
torch_dtype=torch_dtype,
)

def test_TorchAdapter_float(self) -> None:
self._test_TorchAdapter_torch_dtype_deprecated(torch_dtype=torch.float32)

def test_TorchAdapter_float64(self) -> None:
self._test_TorchAdapter_torch_dtype_deprecated(torch_dtype=torch.float64)

def test_TorchAdapter_cuda(self) -> None:
if torch.cuda.is_available():
self.test_TorchAdapter(device=torch.device("cuda"))
Expand Down
27 changes: 6 additions & 21 deletions ax/modelbridge/torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
from copy import deepcopy
from logging import Logger
from typing import Any
from warnings import warn

import numpy as np
import numpy.typing as npt
Expand Down Expand Up @@ -108,7 +107,6 @@ def __init__(
model: TorchGenerator,
transforms: list[type[Transform]],
transform_configs: dict[str, TConfig] | None = None,
torch_dtype: torch.dtype | None = None,
torch_device: torch.device | None = None,
status_quo_name: str | None = None,
status_quo_features: ObservationFeatures | None = None,
Expand All @@ -121,19 +119,6 @@ def __init__(
default_model_gen_options: TConfig | None = None,
fit_only_completed_map_metrics: bool = True,
) -> None:
# This warning is being added while we are on 0.4.3, so it will be
# released in 0.4.4 or 0.5.0. The `torch_dtype` argument can be removed
# in the subsequent minor version. It should also be removed from
# `TorchAdapter` subclasses.
if torch_dtype is not None:
warn(
"The `torch_dtype` argument to `TorchAdapter` is deprecated"
" and will be ignored; data will be in double precision.",
DeprecationWarning,
)

# Note: When `torch_dtype` is removed, this attribute can be removed
self.dtype: torch.dtype = torch.double
self.device = torch_device
# pyre-ignore [4]: Attribute `_default_model_gen_options` of class
# `TorchAdapter` must have a type that does not contain `Any`.
Expand Down Expand Up @@ -311,7 +296,7 @@ def _array_list_to_tensors(self, arrays: list[npt.NDArray]) -> list[Tensor]:
return [self._array_to_tensor(x) for x in arrays]

def _array_to_tensor(self, array: npt.NDArray | list[float]) -> Tensor:
return torch.as_tensor(array, dtype=self.dtype, device=self.device)
return torch.as_tensor(array, dtype=torch.double, device=self.device)

def _convert_observations(
self,
Expand Down Expand Up @@ -367,10 +352,10 @@ def _convert_observations(
raise ValueError(f"Outcome `{outcome}` was not observed.")
X = torch.stack(Xs[outcome], dim=0)
Y = torch.tensor(
Ys[outcome], dtype=self.dtype, device=self.device
Ys[outcome], dtype=torch.double, device=self.device
).unsqueeze(-1)
Yvar = torch.tensor(
Yvars[outcome], dtype=self.dtype, device=self.device
Yvars[outcome], dtype=torch.double, device=self.device
).unsqueeze(-1)
if Yvar.isnan().all():
Yvar = None
Expand Down Expand Up @@ -468,13 +453,13 @@ def _cross_validate(
parameters = self.parameters
X_test = torch.tensor(
[[obsf.parameters[p] for p in parameters] for obsf in cv_test_points],
dtype=self.dtype,
dtype=torch.double,
device=self.device,
)
# Use the model to do the cross validation
f_test, cov_test = none_throws(self.model).cross_validate(
datasets=datasets,
X_test=torch.as_tensor(X_test, dtype=self.dtype, device=self.device),
X_test=torch.as_tensor(X_test, dtype=torch.double, device=self.device),
search_space_digest=search_space_digest,
use_posterior_predictive=use_posterior_predictive,
)
Expand Down Expand Up @@ -1040,7 +1025,7 @@ def _extract_observation_data(
try:
x = torch.tensor(
[obsf.parameters[p] for p in parameters],
dtype=self.dtype,
dtype=torch.double,
device=self.device,
)
except (KeyError, TypeError):
Expand Down
11 changes: 9 additions & 2 deletions ax/storage/json_store/decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,9 @@
"ST_MTGP_NEHVI": "ST_MTGP",
}

# Deprecated model kwargs, to be removed from GStep / GNodes.
_DEPRECATED_MODEL_KWARGS: tuple[str, ...] = ("fit_on_update", "torch_dtype")


@dataclass
class RegistryKwargs:
Expand Down Expand Up @@ -727,7 +730,9 @@ def generation_step_from_json(
generation_step_json
)
kwargs = generation_step_json.pop("model_kwargs", None)
kwargs.pop("fit_on_update", None) # Remove deprecated fit_on_update.
for k in _DEPRECATED_MODEL_KWARGS:
# Remove deprecated kwargs.
kwargs.pop(k, None)
if kwargs is not None:
kwargs = _extract_surrogate_spec_from_surrogate_specs(kwargs)
gen_kwargs = generation_step_json.pop("model_gen_kwargs", None)
Expand Down Expand Up @@ -788,7 +793,9 @@ def model_spec_from_json(
) -> GeneratorSpec:
"""Load GeneratorSpec from JSON."""
kwargs = model_spec_json.pop("model_kwargs", None)
kwargs.pop("fit_on_update", None) # Remove deprecated fit_on_update.
for k in _DEPRECATED_MODEL_KWARGS:
# Remove deprecated model kwargs.
kwargs.pop(k, None)
if kwargs is not None:
kwargs = _extract_surrogate_spec_from_surrogate_specs(kwargs)
gen_kwargs = model_spec_json.pop("model_gen_kwargs", None)
Expand Down
6 changes: 5 additions & 1 deletion ax/storage/json_store/tests/test_json_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -804,7 +804,11 @@ def test_generation_step_backwards_compatibility(self) -> None:
"max_parallelism": None,
"use_update": False,
"enforce_num_trials": True,
"model_kwargs": {"fit_on_update": False, "other_kwarg": 5},
"model_kwargs": {
"fit_on_update": False,
"torch_dtype": torch.double,
"other_kwarg": 5,
},
"model_gen_kwargs": {},
"index": -1,
"should_deduplicate": False,
Expand Down
2 changes: 1 addition & 1 deletion ax/utils/testing/core_stubs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1905,7 +1905,7 @@ def get_generator_run() -> GeneratorRun:
gen_time=5.0,
model_key="Sobol",
model_kwargs={"scramble": False, "torch_device": torch.device("cpu")},
bridge_kwargs={"transforms": Cont_X_trans, "torch_dtype": torch.double},
bridge_kwargs={"transforms": Cont_X_trans},
generation_step_index=0,
candidate_metadata_by_arm_signature={
a.signature: {"md_key": f"md_val_{a.signature}"} for a in arms
Expand Down
2 changes: 0 additions & 2 deletions tutorials/multi_task/multi_task.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -370,7 +370,6 @@
" search_space: Optional[SearchSpace] = None,\n",
" trial_index: Optional[int] = None,\n",
" device: torch.device = torch.device(\"cpu\"),\n",
" dtype: torch.dtype = torch.double,\n",
") -> TorchAdapter:\n",
" \"\"\"Instantiates a Multi-task Gaussian Process (MTGP) model that generates\n",
" points with EI.\n",
Expand Down Expand Up @@ -419,7 +418,6 @@
" data=data,\n",
" transforms=transforms,\n",
" transform_configs=transform_configs,\n",
" torch_dtype=dtype,\n",
" torch_device=device,\n",
" status_quo_features=status_quo_features,\n",
" ),\n",
Expand Down

0 comments on commit 057370d

Please sign in to comment.