@@ -357,7 +357,7 @@ def connect_outputs(mdl):
357357
358358
359359@pytest .mark .skipif (not lineartree_available , reason = "Need Linear-Tree Package" )
360- def test_scaling ():
360+ def test_scaling_only_scaler ():
361361 mean_x_small = np .mean (X_small )
362362 std_x_small = np .std (X_small )
363363 mean_y_small = np .mean (y_small )
@@ -415,6 +415,65 @@ def connect_outputs(mdl):
415415 )
416416 assert y_pred [0 ] == pytest .approx ((solution_1_bigm [1 ] - mean_y_small ) / std_y_small )
417417
418+ @pytest .mark .skipif (not lineartree_available , reason = "Need Linear-Tree Package" )
419+ def test_scaling_bounds_and_scaler ():
420+ mean_x_small = np .mean (X_small )
421+ std_x_small = np .std (X_small )
422+ mean_y_small = np .mean (y_small )
423+ std_y_small = np .std (y_small )
424+ scaled_x = (X_small - mean_x_small ) / std_x_small
425+ scaled_y = (y_small - mean_y_small ) / std_y_small
426+ scaled_input_bounds = {0 : (np .min (scaled_x ), np .max (scaled_x ))}
427+ unscaled_input_bounds = {0 : (np .min (X_small ), np .max (X_small ))}
428+
429+ scaler = omlt .scaling .OffsetScaling (
430+ offset_inputs = [mean_x_small ],
431+ factor_inputs = [std_x_small ],
432+ offset_outputs = [mean_y_small ],
433+ factor_outputs = [std_y_small ],
434+ )
435+
436+ regr = linear_model_tree (scaled_x , scaled_y )
437+
438+ regr .fit (np .reshape (scaled_x , (- 1 , 1 )), scaled_y )
439+
440+ lt_def2 = LinearTreeDefinition (
441+ regr , scaled_input_bounds = scaled_input_bounds , scaling_object = scaler
442+ )
443+ assert lt_def2 .scaled_input_bounds [0 ][0 ] == pytest .approx (scaled_input_bounds [0 ][0 ])
444+ assert lt_def2 .scaled_input_bounds [0 ][1 ] == pytest .approx (scaled_input_bounds [0 ][1 ])
445+ with pytest .raises (
446+ Exception , match = "Input Bounds needed to represent linear trees as MIPs"
447+ ):
448+ LinearTreeDefinition (regr )
449+
450+ formulation = LinearTreeHybridBigMFormulation (lt_def2 )
451+
452+ model1 = pe .ConcreteModel ()
453+ model1 .x = pe .Var (initialize = 0 )
454+ model1 .y = pe .Var (initialize = 0 )
455+ model1 .obj = pe .Objective (expr = 1 )
456+ model1 .lt = OmltBlock ()
457+ model1 .lt .build_formulation (formulation )
458+
459+ @model1 .Constraint ()
460+ def connect_inputs (mdl ):
461+ return mdl .x == mdl .lt .inputs [0 ]
462+
463+ @model1 .Constraint ()
464+ def connect_outputs (mdl ):
465+ return mdl .y == mdl .lt .outputs [0 ]
466+
467+ model1 .x .fix (0.5 )
468+
469+ status_1_bigm = pe .SolverFactory ("scip" ).solve (model1 , tee = True )
470+ pe .assert_optimal_termination (status_1_bigm )
471+ solution_1_bigm = (pe .value (model1 .x ), pe .value (model1 .y ))
472+ y_pred = regr .predict (
473+ np .array ((solution_1_bigm [0 ] - mean_x_small ) / std_x_small ).reshape (1 , - 1 )
474+ )
475+ assert y_pred [0 ] == pytest .approx ((solution_1_bigm [1 ] - mean_y_small ) / std_y_small )
476+
418477
419478#### MULTIVARIATE INPUT TESTING ####
420479
0 commit comments