Skip to content

Commit 72428e9

Browse files
committed
Add test to catch issue 176
1 parent eb2e318 commit 72428e9

File tree

1 file changed

+60
-1
lines changed

1 file changed

+60
-1
lines changed

tests/linear_tree/test_lt_formulation.py

Lines changed: 60 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -357,7 +357,7 @@ def connect_outputs(mdl):
357357

358358

359359
@pytest.mark.skipif(not lineartree_available, reason="Need Linear-Tree Package")
360-
def test_scaling():
360+
def test_scaling_only_scaler():
361361
mean_x_small = np.mean(X_small)
362362
std_x_small = np.std(X_small)
363363
mean_y_small = np.mean(y_small)
@@ -415,6 +415,65 @@ def connect_outputs(mdl):
415415
)
416416
assert y_pred[0] == pytest.approx((solution_1_bigm[1] - mean_y_small) / std_y_small)
417417

418+
@pytest.mark.skipif(not lineartree_available, reason="Need Linear-Tree Package")
419+
def test_scaling_bounds_and_scaler():
420+
mean_x_small = np.mean(X_small)
421+
std_x_small = np.std(X_small)
422+
mean_y_small = np.mean(y_small)
423+
std_y_small = np.std(y_small)
424+
scaled_x = (X_small - mean_x_small) / std_x_small
425+
scaled_y = (y_small - mean_y_small) / std_y_small
426+
scaled_input_bounds = {0: (np.min(scaled_x), np.max(scaled_x))}
427+
unscaled_input_bounds = {0: (np.min(X_small), np.max(X_small))}
428+
429+
scaler = omlt.scaling.OffsetScaling(
430+
offset_inputs=[mean_x_small],
431+
factor_inputs=[std_x_small],
432+
offset_outputs=[mean_y_small],
433+
factor_outputs=[std_y_small],
434+
)
435+
436+
regr = linear_model_tree(scaled_x, scaled_y)
437+
438+
regr.fit(np.reshape(scaled_x, (-1, 1)), scaled_y)
439+
440+
lt_def2 = LinearTreeDefinition(
441+
regr, scaled_input_bounds=scaled_input_bounds, scaling_object=scaler
442+
)
443+
assert lt_def2.scaled_input_bounds[0][0] == pytest.approx(scaled_input_bounds[0][0])
444+
assert lt_def2.scaled_input_bounds[0][1] == pytest.approx(scaled_input_bounds[0][1])
445+
with pytest.raises(
446+
Exception, match="Input Bounds needed to represent linear trees as MIPs"
447+
):
448+
LinearTreeDefinition(regr)
449+
450+
formulation = LinearTreeHybridBigMFormulation(lt_def2)
451+
452+
model1 = pe.ConcreteModel()
453+
model1.x = pe.Var(initialize=0)
454+
model1.y = pe.Var(initialize=0)
455+
model1.obj = pe.Objective(expr=1)
456+
model1.lt = OmltBlock()
457+
model1.lt.build_formulation(formulation)
458+
459+
@model1.Constraint()
460+
def connect_inputs(mdl):
461+
return mdl.x == mdl.lt.inputs[0]
462+
463+
@model1.Constraint()
464+
def connect_outputs(mdl):
465+
return mdl.y == mdl.lt.outputs[0]
466+
467+
model1.x.fix(0.5)
468+
469+
status_1_bigm = pe.SolverFactory("scip").solve(model1, tee=True)
470+
pe.assert_optimal_termination(status_1_bigm)
471+
solution_1_bigm = (pe.value(model1.x), pe.value(model1.y))
472+
y_pred = regr.predict(
473+
np.array((solution_1_bigm[0] - mean_x_small) / std_x_small).reshape(1, -1)
474+
)
475+
assert y_pred[0] == pytest.approx((solution_1_bigm[1] - mean_y_small) / std_y_small)
476+
418477

419478
#### MULTIVARIATE INPUT TESTING ####
420479

0 commit comments

Comments
 (0)