Skip to content

Commit c2560e3

Browse files
maureverCopilottomasfryda
authored
GH-16676 GLM: Remove offset effects (#16749)
Implement the remove offset effects feature into the GLM algorithm. Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Co-authored-by: Tomáš Frýda <tomas.fryda@h2o.ai> Co-authored-by: Claude Code
1 parent d0b0868 commit c2560e3

32 files changed

+4428
-1001
lines changed

h2o-algos/src/main/java/hex/api/MakeGLMModelHandler.java

Lines changed: 97 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -55,46 +55,104 @@ public GLMModelV3 make_model(int version, MakeGLMModelV3 args){
5555
}
5656

5757
public GLMModelV3 make_unrestricted_model(int version, MakeUnrestrictedGLMModelV3 args){
58-
GLMModel model = DKV.getGet(args.model.key());
59-
if(model == null)
60-
throw new IllegalArgumentException("Missing source model " + args.model);
61-
if(model._parms._control_variables == null){
62-
throw new IllegalArgumentException("Source model is not trained with control variables.");
63-
}
64-
Key generatedKey = Key.make(model._key.toString()+"_unrestricted_model");
65-
Key key = args.dest != null ? Key.make(args.dest) : generatedKey;
66-
GLMModel modelContrVars = DKV.getGet(key);
67-
if(modelContrVars != null) {
68-
throw new IllegalArgumentException("Model with "+key+" already exists.");
69-
}
70-
GLMModel.GLMParameters parms = (GLMModel.GLMParameters) model._parms.clone();
71-
GLMModel.GLMParameters inputParms = (GLMModel.GLMParameters) model._input_parms.clone();
72-
GLMModel m = new GLMModel(key, parms,null, model._ymu,
73-
Double.NaN, Double.NaN, -1);
74-
m.setInputParms(inputParms);
75-
m._input_parms._control_variables = null;
76-
m._parms._control_variables = null;
77-
DataInfo dinfo = model.dinfo();
78-
dinfo.setPredictorTransform(TransformType.NONE);
79-
m._output = new GLMOutput(model.dinfo(), model._output._names, model._output._column_types, model._output._domains,
80-
model._output.coefficientNames(), model._output.beta(), model._output._binomial, model._output._multinomial,
81-
model._output._ordinal, null);
82-
ModelMetrics mt = model._output._training_metrics_unrestricted_model;
83-
ModelMetrics mv = model._output._validation_metrics_unrestricted_model;
84-
m._output._training_metrics = mt;
85-
m._output._validation_metrics = mv;
86-
m._output._scoring_history = model._output._scoring_history_unrestricted_model;
87-
m._output._model_summary = model._output._model_summary;
88-
m.resetThreshold(model.defaultThreshold());
89-
m._output._variable_importances = model._output._variable_importances_unrestricted_model;
90-
m._key = key;
91-
92-
DKV.put(key, m);
93-
GLMModelV3 res = new GLMModelV3();
94-
res.fillFromImpl(m);
95-
return res;
58+
MakeDerivedGLMModelV3 newArgs = new MakeDerivedGLMModelV3();
59+
newArgs.model = args.model;
60+
newArgs.dest = args.dest;
61+
newArgs.remove_offset_effects = false;
62+
newArgs.remove_control_variables_effects = false;
63+
return make_derived_model(version, newArgs);
64+
}
65+
66+
public GLMModelV3 make_derived_model(int version, MakeDerivedGLMModelV3 args){
67+
GLMModel model = DKV.getGet(args.model.key());
68+
if (model == null)
69+
throw new IllegalArgumentException("Missing source model " + args.model);
70+
if (model._parms._control_variables == null && !model._parms._remove_offset_effects) {
71+
throw new IllegalArgumentException("Source model is not trained with control variables or remove offset effects.");
72+
}
73+
Key generatedKey;
74+
if ((args.remove_control_variables_effects || args.remove_offset_effects) &&
75+
(model._parms._control_variables == null || !model._parms._remove_offset_effects)) {
76+
throw new IllegalArgumentException("You can set remove_control_variables_effects to true or " +
77+
"remove_offset_effects to true only if control_variables and remove_offset_effects are both set.");
78+
} else if (args.remove_control_variables_effects && args.remove_offset_effects) {
79+
throw new IllegalArgumentException("The remove_control_variables_effects and remove_offset_effects feature " +
80+
"cannot be used together. It produces the same model as the main model.");
81+
} else if (args.remove_offset_effects) {
82+
generatedKey = Key.make(model._key.toString() + "_remove_offset_effects");
83+
} else if (args.remove_control_variables_effects) {
84+
generatedKey = Key.make(model._key.toString() + "_remove_control_variables_effects");
85+
} else {
86+
generatedKey = Key.make(model._key.toString()+"_unrestricted_model");
87+
}
88+
Key key = args.dest != null ? Key.make(args.dest) : generatedKey;
89+
GLMModel modelUnrestricted = DKV.getGet(key);
90+
if (modelUnrestricted != null) {
91+
throw new IllegalArgumentException("Model with "+key+" already exists.");
92+
}
93+
GLMModel.GLMParameters parms = (GLMModel.GLMParameters) model._parms.clone();
94+
GLMModel.GLMParameters inputParms = (GLMModel.GLMParameters) model._input_parms.clone();
95+
GLMModel m = new GLMModel(key, parms,null, model._ymu,
96+
Double.NaN, Double.NaN, -1);
97+
m.setInputParms(inputParms);
98+
if (args.remove_control_variables_effects){
99+
m._input_parms._control_variables = model._parms._control_variables;
100+
m._parms._control_variables = model._parms._control_variables;
101+
m._input_parms._remove_offset_effects = false;
102+
m._parms._remove_offset_effects = false;
103+
} else if(args.remove_offset_effects){
104+
m._input_parms._remove_offset_effects = true;
105+
m._parms._remove_offset_effects = true;
106+
m._input_parms._control_variables = null;
107+
m._parms._control_variables = null;
108+
} else {
109+
m._input_parms._control_variables = null;
110+
m._parms._control_variables = null;
111+
m._input_parms._remove_offset_effects = false;
112+
m._parms._remove_offset_effects = false;
113+
}
114+
DataInfo dinfo = model.dinfo().clone();
115+
dinfo.setPredictorTransform(TransformType.NONE);
116+
m._output = new GLMOutput(model.dinfo(), model._output._names, model._output._column_types, model._output._domains,
117+
model._output.coefficientNames(), model._output.beta(), model._output._binomial, model._output._multinomial,
118+
model._output._ordinal, null);
119+
if (args.remove_control_variables_effects) {
120+
ModelMetrics mt = model._output._training_metrics_restricted_model_contr_vals;
121+
ModelMetrics mv = model._output._validation_metrics_restricted_model_contr_vals;
122+
m._output._training_metrics = mt;
123+
m._output._validation_metrics = mv;
124+
m._output._scoring_history = model._output._scoring_history_restricted_model_contr_vals;
125+
m.resetThreshold(model.defaultThreshold());
126+
m._output._variable_importances = model._output._variable_importances;
127+
m._output.setAndMapControlVariablesNames(model._parms._control_variables);
128+
} else if (args.remove_offset_effects) {
129+
ModelMetrics mt = model._output._training_metrics_restricted_model_ro;
130+
ModelMetrics mv = model._output._validation_metrics_restricted_model_ro;
131+
m._output._training_metrics = mt;
132+
m._output._validation_metrics = mv;
133+
m._output._scoring_history = model._output._scoring_history_restricted_model_ro;
134+
m.resetThreshold(model.defaultThreshold());
135+
m._output._variable_importances = model._output._variable_importances_unrestricted_model;
136+
} else {
137+
ModelMetrics mt = model._output._training_metrics_unrestricted_model;
138+
ModelMetrics mv = model._output._validation_metrics_unrestricted_model;
139+
m._output._training_metrics = mt;
140+
m._output._validation_metrics = mv;
141+
m._output._scoring_history = model._output._scoring_history_unrestricted_model;
142+
m.resetThreshold(model.defaultThreshold());
143+
m._output._variable_importances = model._output._variable_importances_unrestricted_model;
144+
}
145+
m._output._model_summary = model._output._model_summary;
146+
m._key = key;
147+
// setting these flags is important for right scoring
148+
m._useControlVariables = args.remove_control_variables_effects;
149+
m._useRemoveOffsetEffects = args.remove_offset_effects;
150+
151+
DKV.put(key, m);
152+
GLMModelV3 res = new GLMModelV3();
153+
res.fillFromImpl(m);
154+
return res;
96155
}
97-
98156

99157
public GLMRegularizationPathV3 extractRegularizationPath(int v, GLMRegularizationPathV3 args) {
100158
GLMModel model = DKV.getGet(args.model.key());

h2o-algos/src/main/java/hex/api/RegisterAlgos.java

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,11 @@ public void registerEndPoints(RestApiContext context) {
6060

6161
context.registerEndpoint("make_unrestricted_model", "POST /3/MakeUnrestrictedGLMModel",
6262
MakeGLMModelHandler.class, "make_unrestricted_model",
63-
"Make unrestricted GLM model based on existing one with control variables enabled.");
63+
"Make unrestricted GLM model based on existing one with control variables or remove offset effects features enabled.");
64+
65+
context.registerEndpoint("make_derived_model", "POST /3/MakeDerivedGLMModel",
66+
MakeGLMModelHandler.class, "make_derived_model",
67+
"Make derived GLM model based on existing one with control variables or remove offset effects features enabled.");
6468

6569
context.registerEndpoint("glm_regularization_path","GET /3/GetGLMRegPath", MakeGLMModelHandler.class, "extractRegularizationPath",
6670
"Get full regularization path");

h2o-algos/src/main/java/hex/glm/ComputationState.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -200,7 +200,9 @@ void copyCheckModel2State(GLMModel model, int[][] _gamColIndices) {
200200
// make sure model._betaCndCheckpoint is of the right size
201201
if (model._betaCndCheckpoint != null) {
202202
if (_activeData._activeCols == null || (_activeData._activeCols.length != model._betaCndCheckpoint.length)) {
203-
double[] betaCndCheckpoint = ArrayUtils.expandAndScatter(model._betaCndCheckpoint, coefLen,
203+
double[] betaCndCheckpoint = modelOutput._submodels[submodelInd].idxs == null
204+
? model._betaCndCheckpoint
205+
: ArrayUtils.expandAndScatter(model._betaCndCheckpoint, coefLen,
204206
modelOutput._submodels[submodelInd].idxs); // expand betaCndCheckpoint out
205207
if (_activeData._activeCols != null) // contract the betaCndCheckpoint to the right activeCol length
206208
betaCndCheckpoint = extractSubRange(betaCndCheckpoint.length, 0, activeData()._activeCols, betaCndCheckpoint);

0 commit comments

Comments
 (0)