Skip to content

Commit 74d85c0

Browse files
authored
Fix output bound calculations (#168)
This PR is to ensure that the bounds on unscaled output vars are calculated from unscaled input bounds. Likewise, it ensures that bounds on scaled output vars are calculated from scaled input bounds **Legal Acknowledgement**\ By contributing to this software project, I agree my contributions are submitted under the BSD license. I represent I am authorized to make the contributions and grant the license. If my employer has rights to intellectual property that includes these contributions, I represent that I have received permission to make contributions and grant the required license on behalf of that employer.
1 parent f85e4e1 commit 74d85c0

File tree

3 files changed

+76
-13
lines changed

3 files changed

+76
-13
lines changed

src/omlt/linear_tree/lt_definition.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ def __init__(
5757
self.__model = lt_regressor
5858
self.__scaling_object = scaling_object
5959

60+
is_scaled = True
6061
# Process input bounds to insure scaled input bounds exist for formulations
6162
if scaled_input_bounds is None:
6263
if unscaled_input_bounds is not None and scaling_object is not None:
@@ -75,12 +76,14 @@ def __init__(
7576
# input bounds = unscaled input bounds
7677
elif unscaled_input_bounds is not None and scaling_object is None:
7778
scaled_input_bounds = unscaled_input_bounds
79+
is_scaled = False
7880
elif unscaled_input_bounds is None:
7981
msg = "Input Bounds needed to represent linear trees as MIPs"
8082
raise ValueError(msg)
8183

8284
self.__unscaled_input_bounds = unscaled_input_bounds
8385
self.__scaled_input_bounds = scaled_input_bounds
86+
self.__is_scaled = is_scaled
8487

8588
self.__splits, self.__leaves, self.__thresholds = _parse_tree_data(
8689
lt_regressor, scaled_input_bounds
@@ -99,6 +102,16 @@ def scaled_input_bounds(self):
99102
"""Returns dict containing scaled input bounds."""
100103
return self.__scaled_input_bounds
101104

105+
@property
106+
def unscaled_input_bounds(self):
107+
"""Returns dict containing unscaled input bounds."""
108+
return self.__unscaled_input_bounds
109+
110+
@property
111+
def is_scaled(self):
112+
"""Returns bool indicating whether model is scaled."""
113+
return self.__is_scaled
114+
102115
@property
103116
def splits(self):
104117
"""Returns dict containing split information."""

src/omlt/linear_tree/lt_formulation.py

Lines changed: 32 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -99,11 +99,17 @@ def _build_formulation(self):
9999
self.model_definition.scaled_input_bounds,
100100
)
101101

102+
input_vars = self.block.scaled_inputs
103+
if self.model_definition.is_scaled is True:
104+
output_vars = self.block.scaled_outputs
105+
else:
106+
output_vars = self.block.outputs
107+
102108
_add_gdp_formulation_to_block(
103109
block=self.block,
104110
model_definition=self.model_definition,
105-
input_vars=self.block.scaled_inputs,
106-
output_vars=self.block.scaled_outputs,
111+
input_vars=input_vars,
112+
output_vars=output_vars,
107113
transformation=self.transformation,
108114
epsilon=self.epsilon,
109115
include_leaf_equalities=True,
@@ -181,12 +187,16 @@ def _build_formulation(self):
181187
)
182188

183189
input_vars = self.block.scaled_inputs
190+
if self.model_definition.is_scaled is True:
191+
output_vars = self.block.scaled_outputs
192+
else:
193+
output_vars = self.block.outputs
184194

185195
_add_gdp_formulation_to_block(
186196
block=block,
187197
model_definition=self.model_definition,
188198
input_vars=input_vars,
189-
output_vars=self.block.scaled_outputs,
199+
output_vars=output_vars,
190200
transformation="custom",
191201
epsilon=self.epsilon,
192202
include_leaf_equalities=False,
@@ -285,7 +295,8 @@ def _add_gdp_formulation_to_block( # noqa: PLR0913
285295
(default: True)
286296
"""
287297
leaves = model_definition.leaves
288-
input_bounds = model_definition.scaled_input_bounds
298+
scaled_input_bounds = model_definition.scaled_input_bounds
299+
unscaled_input_bounds = model_definition.unscaled_input_bounds
289300
n_inputs = model_definition.n_inputs
290301

291302
# The set of leaves and the set of features
@@ -295,17 +306,25 @@ def _add_gdp_formulation_to_block( # noqa: PLR0913
295306

296307
# Use the input_bounds and the linear models in the leaves to calculate
297308
# the lower and upper bounds on the output variable. Required for Pyomo.GDP
298-
output_bounds = _build_output_bounds(model_definition, input_bounds)
309+
scaled_output_bounds = _build_output_bounds(model_definition, scaled_input_bounds)
310+
unscaled_output_bounds = _build_output_bounds(
311+
model_definition, unscaled_input_bounds
312+
)
299313

300314
# Ouptuts are automatically scaled based on whether inputs are scaled
301-
block.outputs.setub(output_bounds[1])
302-
block.outputs.setlb(output_bounds[0])
303-
block.scaled_outputs.setub(output_bounds[1])
304-
block.scaled_outputs.setlb(output_bounds[0])
305-
306-
block.intermediate_output = pe.Var(
307-
tree_ids, bounds=(output_bounds[0], output_bounds[1])
308-
)
315+
block.outputs.setub(unscaled_output_bounds[1])
316+
block.outputs.setlb(unscaled_output_bounds[0])
317+
block.scaled_outputs.setub(scaled_output_bounds[1])
318+
block.scaled_outputs.setlb(scaled_output_bounds[0])
319+
320+
if model_definition.is_scaled is True:
321+
block.intermediate_output = pe.Var(
322+
tree_ids, bounds=(scaled_output_bounds[0], scaled_output_bounds[1])
323+
)
324+
else:
325+
block.intermediate_output = pe.Var(
326+
tree_ids, bounds=(unscaled_output_bounds[0], unscaled_output_bounds[1])
327+
)
309328

310329
# Create a disjunct for each leaf containing the bound constraints
311330
# and the linear model expression.

tests/linear_tree/test_lt_formulation.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,8 +100,12 @@ def test_linear_tree_model_single_var(): # noqa: C901
100100
splits = ltmodel_small.splits
101101
leaves = ltmodel_small.leaves
102102
thresholds = ltmodel_small.thresholds
103+
is_scaled = ltmodel_small.is_scaled
104+
unscaled_input_bounds = ltmodel_small.unscaled_input_bounds
103105

104106
assert scaled_input_bounds is not None
107+
assert unscaled_input_bounds is not None
108+
assert not is_scaled
105109
assert n_inputs == 1
106110
assert n_outputs == 1
107111
# test for splits
@@ -384,6 +388,33 @@ def test_scaling():
384388
):
385389
LinearTreeDefinition(regr)
386390

391+
formulation = LinearTreeHybridBigMFormulation(lt_def2)
392+
393+
model1 = pe.ConcreteModel()
394+
model1.x = pe.Var(initialize=0)
395+
model1.y = pe.Var(initialize=0)
396+
model1.obj = pe.Objective(expr=1)
397+
model1.lt = OmltBlock()
398+
model1.lt.build_formulation(formulation)
399+
400+
@model1.Constraint()
401+
def connect_inputs(mdl):
402+
return mdl.x == mdl.lt.inputs[0]
403+
404+
@model1.Constraint()
405+
def connect_outputs(mdl):
406+
return mdl.y == mdl.lt.outputs[0]
407+
408+
model1.x.fix(0.5)
409+
410+
status_1_bigm = pe.SolverFactory("scip").solve(model1, tee=True)
411+
pe.assert_optimal_termination(status_1_bigm)
412+
solution_1_bigm = (pe.value(model1.x), pe.value(model1.y))
413+
y_pred = regr.predict(
414+
np.array((solution_1_bigm[0] - mean_x_small) / std_x_small).reshape(1, -1)
415+
)
416+
assert y_pred[0] == pytest.approx((solution_1_bigm[1] - mean_y_small) / std_y_small)
417+
387418

388419
#### MULTIVARIATE INPUT TESTING ####
389420

0 commit comments

Comments
 (0)