Skip to content

Commit

Permalink
Specify fidelity_parameters in Ax (#122)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #122

Specify fidelity_dims (which is used by fidelity models) in `fit()`

Put fidelity parameters to the last columns.

Reviewed By: bletham

Differential Revision: D16122299

fbshipit-source-id: a5d3139287e3b3afe87b81424b33a9dbe3ccd699
  • Loading branch information
VilockLi authored and facebook-github-bot committed Jul 12, 2019
1 parent 37a0006 commit e5b1f5b
Show file tree
Hide file tree
Showing 16 changed files with 97 additions and 22 deletions.
14 changes: 12 additions & 2 deletions ax/modelbridge/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,11 @@ def _fit(
) -> None:
# Convert observations to arrays
self.parameters = list(search_space.parameters.keys())
# move fidelity parameters to the last columns
for para in search_space.parameters:
if search_space.parameters[para].is_fidelity:
self.parameters.remove(para)
self.parameters.append(para)
all_metric_names: Set[str] = set()
for od in observation_data:
all_metric_names.update(od.metric_names)
Expand All @@ -60,7 +65,9 @@ def _fit(
)
self.training_in_design = in_design
# Extract bounds and task features
bounds, task_features = get_bounds_and_task(search_space, self.parameters)
bounds, task_features, fidelity_features = get_bounds_and_task(
search_space, self.parameters
)

# Fit
self._model_fit(
Expand All @@ -71,6 +78,7 @@ def _fit(
bounds=bounds,
task_features=task_features,
feature_names=self.parameters,
fidelity_features=fidelity_features,
)

def _model_fit(
Expand All @@ -82,6 +90,7 @@ def _model_fit(
bounds: List[Tuple[float, float]],
task_features: List[int],
feature_names: List[str],
fidelity_features: List[int],
) -> None:
"""Fit the model, given numpy types.
"""
Expand All @@ -93,6 +102,7 @@ def _model_fit(
bounds=bounds,
task_features=task_features,
feature_names=feature_names,
fidelity_features=fidelity_features,
)

def _update(
Expand Down Expand Up @@ -150,7 +160,7 @@ def _gen(
if not self.parameters: # pragma: no cover
raise ValueError(FIT_MODEL_ERROR.format(action="_gen"))
# Extract bounds
bounds, _ = get_bounds_and_task(search_space, self.parameters)
bounds, _, _ = get_bounds_and_task(search_space, self.parameters)
if optimization_config is None:
raise ValueError(
"ArrayModelBridge requires an OptimizationConfig to be specified"
Expand Down
8 changes: 6 additions & 2 deletions ax/modelbridge/modelbridge_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,13 @@ def extract_parameter_constraints(

def get_bounds_and_task(
search_space: SearchSpace, param_names: List[str]
) -> Tuple[List[Tuple[float, float]], List[int]]:
) -> Tuple[List[Tuple[float, float]], List[int], List[int]]:
"""Extract box bounds from a search space in the usual Scipy format.
Identify integer parameters as task features.
"""
bounds: List[Tuple[float, float]] = []
task_features: List[int] = []
fidelity_features: List[int] = []
for i, p_name in enumerate(param_names):
p = search_space.parameters[p_name]
# Validation
Expand All @@ -48,7 +49,10 @@ def get_bounds_and_task(
bounds.append((p.lower, p.upper))
if p.parameter_type == ParameterType.INT:
task_features.append(i)
return bounds, task_features
if p.is_fidelity:
fidelity_features.append(i)

return bounds, task_features, fidelity_features


def get_fixed_features(
Expand Down
2 changes: 1 addition & 1 deletion ax/modelbridge/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def _gen(
) -> Tuple[List[ObservationFeatures], List[float], Optional[ObservationFeatures]]:
"""Generate new candidates according to a search_space."""
# Extract parameter values
bounds, _ = get_bounds_and_task(search_space, self.parameters)
bounds, _, _ = get_bounds_and_task(search_space, self.parameters)
# Get fixed features
fixed_features_dict = get_fixed_features(fixed_features, self.parameters)
# Extract param constraints
Expand Down
26 changes: 18 additions & 8 deletions ax/modelbridge/tests/test_numpy_modelbridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
class NumpyModelBridgeTest(TestCase):
def setUp(self):
x = RangeParameter("x", ParameterType.FLOAT, lower=0, upper=1)
y = RangeParameter("y", ParameterType.FLOAT, lower=1, upper=2)
y = RangeParameter("y", ParameterType.FLOAT, lower=1, upper=2, is_fidelity=True)
z = RangeParameter("z", ParameterType.FLOAT, lower=0, upper=5)
self.parameters = [x, y, z]
parameter_constraints = [
Expand Down Expand Up @@ -83,16 +83,17 @@ def testFitAndUpdate(self, mock_init):
self.observation_features + [sq_feat],
self.observation_data + [sq_data],
)
self.assertEqual(ma.parameters, ["x", "y", "z"])
self.assertEqual(ma.parameters, ["x", "z", "y"])
self.assertEqual(sorted(ma.outcomes), ["a", "b"])
self.assertEqual(ma.training_in_design, [True, True, True, False])
Xs = {
"a": np.array([[0.2, 1.2, 3.0], [0.4, 1.4, 3.0], [0.6, 1.6, 3]]),
"b": np.array([[0.2, 1.2, 3.0], [0.4, 1.4, 3.0]]),
"a": np.array([[0.2, 3.0, 1.2], [0.4, 3.0, 1.4], [0.6, 3.0, 1.6]]),
"b": np.array([[0.2, 3.0, 1.2], [0.4, 3.0, 1.4]]),
}
Ys = {"a": np.array([[1.0], [2.0], [3.0]]), "b": np.array([[-1.0], [-2.0]])}
Yvars = {"a": np.array([[1.0], [2.0], [3.0]]), "b": np.array([[6.0], [7.0]])}
bounds = [(0.0, 1.0), (1.0, 2.0), (0.0, 5.0)]
# put fidelity parameter to the last column
bounds = [(0.0, 1.0), (0.0, 5.0), (1.0, 2.0)]
model_fit_args = model.fit.mock_calls[0][2]
for i, x in enumerate(model_fit_args["Xs"]):
self.assertTrue(np.array_equal(x, Xs[ma.outcomes[i]]))
Expand All @@ -101,7 +102,7 @@ def testFitAndUpdate(self, mock_init):
for i, v in enumerate(model_fit_args["Yvars"]):
self.assertTrue(np.array_equal(v, Yvars[ma.outcomes[i]]))
self.assertEqual(model_fit_args["bounds"], bounds)
self.assertEqual(model_fit_args["feature_names"], ["x", "y", "z"])
self.assertEqual(model_fit_args["feature_names"], ["x", "z", "y"])

# And update
ma.training_in_design.extend([True, True, True, True])
Expand Down Expand Up @@ -312,15 +313,24 @@ def testCrossValidate(self, mock_init, mock_cv):
self.assertEqual(od, self.observation_data[i])

def testGetBoundsAndTask(self):
bounds, task_features = get_bounds_and_task(self.search_space, ["x", "y", "z"])
bounds, task_features, fidelity_features = get_bounds_and_task(
self.search_space, ["x", "y", "z"]
)
self.assertEqual(bounds, [(0.0, 1.0), (1.0, 2.0), (0.0, 5.0)])
self.assertEqual(task_features, [])
self.assertEqual(fidelity_features, [1])
bounds, task_features, fidelity_features = get_bounds_and_task(
self.search_space, ["x", "z"]
)
self.assertEqual(fidelity_features, [])
# Test that Int param is treated as task feature
search_space = SearchSpace(self.parameters)
search_space._parameters["x"] = RangeParameter(
"x", ParameterType.INT, lower=1, upper=4
)
bounds, task_features = get_bounds_and_task(search_space, ["x", "y", "z"])
bounds, task_features, fidelity_features = get_bounds_and_task(
search_space, ["x", "y", "z"]
)
self.assertEqual(task_features, [0])
# Test validation
search_space._parameters["x"] = ChoiceParameter(
Expand Down
2 changes: 1 addition & 1 deletion ax/modelbridge/tests/test_torch_modelbridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ def testTorchModelBridge(self, mock_init):
bounds=None,
feature_names=[],
task_features=[],
fidelity_features=[],
)
model_fit_args = model.fit.mock_calls[0][2]
self.assertTrue(
Expand Down Expand Up @@ -111,7 +112,6 @@ def testTorchModelBridge(self, mock_init):
gen_args = model.gen.mock_calls[0][2]
self.assertEqual(gen_args["n"], 3)
self.assertEqual(gen_args["bounds"], [(0, 1)])
print(gen_args["objective_weights"])
self.assertTrue(
torch.equal(
gen_args["objective_weights"],
Expand Down
2 changes: 2 additions & 0 deletions ax/modelbridge/torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ def _model_fit(
bounds: List[Tuple[float, float]],
task_features: List[int],
feature_names: List[str],
fidelity_features: List[int],
) -> None:
self.model = model
# Convert numpy arrays to torch tensors
Expand All @@ -104,6 +105,7 @@ def _model_fit(
bounds=bounds,
task_features=task_features,
feature_names=feature_names,
fidelity_features=fidelity_features,
)

def _model_update(
Expand Down
1 change: 1 addition & 0 deletions ax/models/numpy/randomforest.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ def fit(
bounds: List[Tuple[float, float]],
task_features: List[int],
feature_names: List[str],
fidelity_features: List[int],
) -> None:
for i, X in enumerate(Xs):
self.models.append(
Expand Down
3 changes: 3 additions & 0 deletions ax/models/numpy_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ def fit(
bounds: List[Tuple[float, float]],
task_features: List[int],
feature_names: List[str],
fidelity_features: List[int],
) -> None:
"""Fit model to m outcomes.
Expand All @@ -35,6 +36,8 @@ def fit(
task_features: Columns of X that take integer values and should be
treated as task parameters.
feature_names: Names of each column of X.
fidelity_features: Columns of X that should be treated as fidelity
parameters.
"""
pass

Expand Down
16 changes: 14 additions & 2 deletions ax/models/tests/test_botorch_defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,23 @@ def test_task_feature(self, gp_mock, get_model_mock):
x = [torch.zeros(2, 2)]
y = [torch.zeros(2, 1)]
yvars = [torch.ones(2, 1)]
get_and_fit_model(Xs=x, Ys=y, Yvars=yvars, task_features=[1], state_dict=[])
get_and_fit_model(
Xs=x,
Ys=y,
Yvars=yvars,
task_features=[1],
fidelity_features=[],
state_dict=[],
)
# Check that task feature was correctly passed to _get_model
self.assertEqual(get_model_mock.mock_calls[0][2]["task_feature"], 1)

with self.assertRaises(ValueError):
get_and_fit_model(
Xs=x, Ys=y, Yvars=yvars, task_features=[0, 1], state_dict=[]
Xs=x,
Ys=y,
Yvars=yvars,
task_features=[0, 1],
fidelity_features=[],
state_dict=[],
)
11 changes: 10 additions & 1 deletion ax/models/tests/test_botorch_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ def test_BotorchModel(self, dtype=torch.float, cuda=False):
bounds=bounds,
task_features=task_features,
feature_names=feature_names,
fidelity_features=[],
)
_mock_fit_model.assert_called_once()
# Check attributes
Expand Down Expand Up @@ -81,6 +82,7 @@ def test_BotorchModel(self, dtype=torch.float, cuda=False):
bounds=bounds,
task_features=task_features,
feature_names=feature_names,
fidelity_features=[],
)
_mock_fit_model.assert_called_once()

Expand Down Expand Up @@ -263,7 +265,12 @@ def test_BotorchModel(self, dtype=torch.float, cuda=False):
key: torch.tensor(val, **tkwargs) for key, val in true_state_dict.items()
}
model = get_and_fit_model(
Xs=Xs1, Ys=Ys1, Yvars=Yvars1, task_features=[], state_dict=true_state_dict
Xs=Xs1,
Ys=Ys1,
Yvars=Yvars1,
task_features=[],
fidelity_features=[],
state_dict=true_state_dict,
)
for k, v in chain(model.named_parameters(), model.named_buffers()):
self.assertTrue(torch.equal(true_state_dict[k], v))
Expand Down Expand Up @@ -292,6 +299,7 @@ def test_BotorchModelOneOutcome(self):
bounds=bounds,
task_features=task_features,
feature_names=feature_names,
fidelity_features=[],
)
_mock_fit_model.assert_called_once()
X = torch.rand(2, 3, dtype=torch.float)
Expand Down Expand Up @@ -321,6 +329,7 @@ def test_BotorchModelConstraints(self):
bounds=bounds,
task_features=task_features,
feature_names=feature_names,
fidelity_features=[],
)
_mock_fit_model.assert_called_once()

Expand Down
1 change: 1 addition & 0 deletions ax/models/tests/test_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ def testNumpyModelFit(self):
bounds=[(0, 1)],
task_features=[],
feature_names=["x"],
fidelity_features=[],
)

def testNumpyModelPredict(self):
Expand Down
1 change: 1 addition & 0 deletions ax/models/tests/test_randomforest.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ def testRFModel(self):
bounds=[(0, 1)] * 2,
task_features=[],
feature_names=["x1", "x2"],
fidelity_features=[],
)
self.assertEqual(len(m.models), 2)
self.assertEqual(len(m.models[0].estimators_), 5)
Expand Down
1 change: 1 addition & 0 deletions ax/models/tests/test_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ def testTorchModelFit(self):
bounds=[(0, 1)],
task_features=[],
feature_names=["x1"],
fidelity_features=[],
)

def testTorchModelPredict(self):
Expand Down
26 changes: 21 additions & 5 deletions ax/models/torch/botorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,14 +91,21 @@ class BotorchModel(TorchModel):
::
model_constructor(
Xs, Ys, Yvars, task_features, state_dict, **kwargs
Xs,
Ys,
Yvars,
task_features,
state_dict,
fidelity_features,
**kwargs
) -> model
Here `Xs`, `Ys`, `Yvars` are lists of tensors (one element per outcome),
`task_features` identifies columns of Xs that should be modeled
as a task, `state_dict` is a pytorch module state dict, and `model` is a
botorch `Model`. Optional kwargs are being passed through from the
`BotorchModel` constructor. This callable is assumed to return a fitted
as a task, `state_dict` is a pytorch module state dict, 'fidelity_features' is
a list of ints that specify the positions of fidelity parameters in 'Xs',
and `model` is a botorch `Model`. Optional kwargs are being passed through
from the `BotorchModel` constructor. This callable is assumed to return a fitted
botorch model that has the same dtype and lives on the same device as the
input tensors.
Expand Down Expand Up @@ -194,6 +201,7 @@ def __init__(
self.dtype = None
self.device = None
self.task_features: List[int] = []
self.fidelity_features: List[int] = []

@copy_doc(TorchModel.fit)
def fit(
Expand All @@ -204,15 +212,21 @@ def fit(
bounds: List[Tuple[float, float]],
task_features: List[int],
feature_names: List[str],
fidelity_features: List[int],
) -> None:
self.dtype = Xs[0].dtype
self.device = Xs[0].device
self.Xs = Xs
self.Ys = Ys
self.Yvars = Yvars
self.task_features = task_features
self.fidelity_features = fidelity_features
self.model = self.model_constructor( # pyre-ignore [28]
Xs=Xs, Ys=Ys, Yvars=Yvars, task_features=self.task_features
Xs=Xs,
Ys=Ys,
Yvars=Yvars,
task_features=self.task_features,
fidelity_features=self.fidelity_features,
)

@copy_doc(TorchModel.predict)
Expand Down Expand Up @@ -352,6 +366,7 @@ def cross_validate(
Yvars=Yvars_train,
task_features=self.task_features,
state_dict=state_dict,
fidelity_features=self.fidelity_features,
)
return self.model_predictor(model=model, X=X_test) # pyre-ignore: [28]

Expand All @@ -373,4 +388,5 @@ def update(self, Xs: List[Tensor], Ys: List[Tensor], Yvars: List[Tensor]) -> Non
Yvars=self.Yvars,
task_features=self.task_features,
state_dict=state_dict,
fidelity_features=self.fidelity_features,
)
2 changes: 2 additions & 0 deletions ax/models/torch/botorch_defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ def get_and_fit_model(
Ys: List[Tensor],
Yvars: List[Tensor],
task_features: List[int],
fidelity_features: List[int],
state_dict: Optional[Dict[str, Tensor]] = None,
**kwargs: Any,
) -> GPyTorchModel:
Expand All @@ -41,6 +42,7 @@ def get_and_fit_model(
Ys: List of Y data, one tensor per outcome
Yvars: List of observed variance of Ys.
task_features: List of columns of X that are tasks.
fidelity_features: List of columns of X that are fidelity parameters.
state_dict: If provided, will set model parameters to this state
dictionary. Otherwise, will fit the model.
Expand Down
Loading

0 comments on commit e5b1f5b

Please sign in to comment.