Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 6 additions & 7 deletions src/omlt/linear_tree/lt_formulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,14 +314,8 @@ def _add_gdp_formulation_to_block( # noqa: PLR0913

# Use the input_bounds and the linear models in the leaves to calculate
# the lower and upper bounds on the output variable. Required for Pyomo.GDP
scaled_output_bounds = _build_output_bounds(model_definition, scaled_input_bounds)
unscaled_output_bounds = _build_output_bounds(
model_definition, unscaled_input_bounds
)

# Ouptuts are automatically scaled based on whether inputs are scaled
block.outputs.setub(unscaled_output_bounds[1])
block.outputs.setlb(unscaled_output_bounds[0])
scaled_output_bounds = _build_output_bounds(model_definition, scaled_input_bounds)
block.scaled_outputs.setub(scaled_output_bounds[1])
block.scaled_outputs.setlb(scaled_output_bounds[0])

Expand All @@ -330,6 +324,11 @@ def _add_gdp_formulation_to_block( # noqa: PLR0913
tree_ids, bounds=(scaled_output_bounds[0], scaled_output_bounds[1])
)
else:
unscaled_output_bounds = _build_output_bounds(
model_definition, unscaled_input_bounds
)
block.outputs.setub(unscaled_output_bounds[1])
block.outputs.setlb(unscaled_output_bounds[0])
block.intermediate_output = pe.Var(
tree_ids, bounds=(unscaled_output_bounds[0], unscaled_output_bounds[1])
)
Expand Down
60 changes: 59 additions & 1 deletion tests/linear_tree/test_lt_formulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,7 +357,7 @@ def connect_outputs(mdl):


@pytest.mark.skipif(not lineartree_available, reason="Need Linear-Tree Package")
def test_scaling():
def test_scaling_only_scaler():
mean_x_small = np.mean(X_small)
std_x_small = np.std(X_small)
mean_y_small = np.mean(y_small)
Expand Down Expand Up @@ -415,6 +415,64 @@ def connect_outputs(mdl):
)
assert y_pred[0] == pytest.approx((solution_1_bigm[1] - mean_y_small) / std_y_small)

@pytest.mark.skipif(not lineartree_available, reason="Need Linear-Tree Package")
def test_scaling_bounds_and_scaler():
mean_x_small = np.mean(X_small)
std_x_small = np.std(X_small)
mean_y_small = np.mean(y_small)
std_y_small = np.std(y_small)
scaled_x = (X_small - mean_x_small) / std_x_small
scaled_y = (y_small - mean_y_small) / std_y_small
scaled_input_bounds = {0: (np.min(scaled_x), np.max(scaled_x))}

scaler = omlt.scaling.OffsetScaling(
offset_inputs=[mean_x_small],
factor_inputs=[std_x_small],
offset_outputs=[mean_y_small],
factor_outputs=[std_y_small],
)

regr = linear_model_tree(scaled_x, scaled_y)

regr.fit(np.reshape(scaled_x, (-1, 1)), scaled_y)

lt_def2 = LinearTreeDefinition(
regr, scaled_input_bounds=scaled_input_bounds, scaling_object=scaler
)
assert lt_def2.scaled_input_bounds[0][0] == pytest.approx(scaled_input_bounds[0][0])
assert lt_def2.scaled_input_bounds[0][1] == pytest.approx(scaled_input_bounds[0][1])
with pytest.raises(
Exception, match="Input Bounds needed to represent linear trees as MIPs"
):
LinearTreeDefinition(regr)

formulation = LinearTreeHybridBigMFormulation(lt_def2)

model1 = pe.ConcreteModel()
model1.x = pe.Var(initialize=0)
model1.y = pe.Var(initialize=0)
model1.obj = pe.Objective(expr=1)
model1.lt = OmltBlock()
model1.lt.build_formulation(formulation)

@model1.Constraint()
def connect_inputs(mdl):
return mdl.x == mdl.lt.inputs[0]

@model1.Constraint()
def connect_outputs(mdl):
return mdl.y == mdl.lt.outputs[0]

model1.x.fix(0.5)

status_1_bigm = pe.SolverFactory("scip").solve(model1, tee=True)
pe.assert_optimal_termination(status_1_bigm)
solution_1_bigm = (pe.value(model1.x), pe.value(model1.y))
y_pred = regr.predict(
np.array((solution_1_bigm[0] - mean_x_small) / std_x_small).reshape(1, -1)
)
assert y_pred[0] == pytest.approx((solution_1_bigm[1] - mean_y_small) / std_y_small)


#### MULTIVARIATE INPUT TESTING ####

Expand Down
Loading