Skip to content

Commit 00d2bd0

Browse files
committed
prototype
1 parent 19a1d76 commit 00d2bd0

5 files changed

Lines changed: 384 additions & 5 deletions

File tree

h2o-algos/src/main/java/hex/glm/GLMModel.java

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2245,7 +2245,10 @@ public double score(double[] data) {
22452245
classCtx.add(new CodeGenerator() {
22462246
@Override
22472247
public void generate(JCodeSB out) {
2248-
JCodeGen.toClassWithArray(out, "public static", "BETA", beta_internal()); // "The Coefficients"
2248+
if (_parms._control_variables != null && _parms._control_variables.length > 0)
2249+
JCodeGen.toClassWithArray(out, "public static", "BETA", _output.getControlValBeta(beta_internal().clone())); // "The Coefficients"
2250+
else
2251+
JCodeGen.toClassWithArray(out, "public static", "BETA", beta_internal()); // "The Coefficients"
22492252
JCodeGen.toClassWithArray(out, "static", "NUM_MEANS", _output._dinfo._numNAFill,"Imputed numeric values");
22502253
JCodeGen.toClassWithArray(out, "static", "CAT_MODES", _output._dinfo.catNAFill(),"Imputed categorical values.");
22512254
JCodeGen.toStaticVar(out, "CATOFFS", dinfo()._catOffsets, "Categorical Offsets");
@@ -2409,15 +2412,24 @@ protected ModelMetrics.MetricBuilder scoreMetrics(Frame adaptFrm) {
24092412

24102413
@Override
24112414
public boolean haveMojo() {
2412-
if (_parms._control_variables != null && _parms._control_variables.length>0)
2413-
return false;
2415+
if (_parms._control_variables != null && _parms._control_variables.length > 0)
2416+
return _parms.interactionSpec() == null &&
2417+
!_parms._family.equals(Family.multinomial) &&
2418+
!_parms._family.equals(Family.ordinal) &&
2419+
super.haveMojo();
24142420
if (_parms.interactionSpec() == null)
24152421
return super.haveMojo();
24162422
return false;
24172423
}
24182424

24192425
@Override
24202426
public boolean havePojo() {
2427+
if (_parms._control_variables != null && _parms._control_variables.length > 0)
2428+
return _parms.interactionSpec() == null &&
2429+
_parms._offset_column == null &&
2430+
!_parms._family.equals(Family.multinomial) &&
2431+
!_parms._family.equals(Family.ordinal) &&
2432+
super.haveMojo();
24212433
if (_parms.interactionSpec() == null && _parms._offset_column == null) return super.havePojo();
24222434
else return false;
24232435
}

h2o-algos/src/main/java/hex/glm/GLMMojoWriter.java

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,11 @@ protected void writeModelData() throws IOException {
3131
writekv("num_means", model.dinfo().numNAFill());
3232
writekv("cat_modes", model.dinfo().catNAFill());
3333
}
34-
35-
writekv("beta", model.beta_internal());
34+
35+
if (model._parms._control_variables != null && model._parms._control_variables.length > 0)
36+
writekv("beta", model._output.getControlValBeta(model.beta_internal().clone()));
37+
else
38+
writekv("beta", model.beta_internal());
3639

3740
writekv("family", model._parms._family);
3841
writekv("link", model._parms._link);
Lines changed: 225 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,225 @@
1+
package hex.glm;
2+
3+
import hex.CreateFrame;
4+
import hex.SplitFrame;
5+
import hex.api.MakeGLMModelHandler;
6+
import hex.glm.GLMModel.GLMParameters;
7+
import hex.glm.GLMModel.GLMParameters.Family;
8+
import hex.schemas.MakeUnrestrictedGLMModelV3;
9+
import org.junit.Assert;
10+
import org.junit.Test;
11+
import org.junit.runner.RunWith;
12+
import water.*;
13+
import water.api.schemas3.KeyV3;
14+
import water.fvec.Frame;
15+
import water.fvec.Vec;
16+
import water.runner.CloudSize;
17+
import water.runner.H2ORunner;
18+
19+
import java.util.Random;
20+
21+
/**
22+
* Tests MOJO/POJO scoring for GLM models with control variables across all supported families.
23+
* For each family, verifies that both restricted and unrestricted models produce correct
24+
* MOJO/POJO predictions and that their predictions differ from each other.
25+
*/
26+
@RunWith(H2ORunner.class)
27+
@CloudSize(1)
28+
public class GLMPredMojoControlVariablesTest extends TestUtil {
29+
30+
double _tol = 1e-6;
31+
32+
@Test
33+
public void testGaussianPredMojoControlVariables() {
34+
try {
35+
Scope.enter();
36+
Frame[] data = createAndSplitData(1, true, false, "gaussian");
37+
GLMParameters params = createCommonParams(Family.gaussian, data[0]);
38+
checkMojoRestrictedAndUnrestricted(params, data[1], "gaussian");
39+
} finally {
40+
Scope.exit();
41+
}
42+
}
43+
44+
@Test
45+
public void testPoissonPredMojoControlVariables() {
46+
try {
47+
Scope.enter();
48+
Frame[] data = createAndSplitData(20, false, true, "poisson");
49+
GLMParameters params = createCommonParams(Family.poisson, data[0]);
50+
checkMojoRestrictedAndUnrestricted(params, data[1], "poisson");
51+
} finally {
52+
Scope.exit();
53+
}
54+
}
55+
56+
@Test
57+
public void testGammaPredMojoControlVariables() {
58+
try {
59+
Scope.enter();
60+
Frame[] data = createAndSplitData(1, true, false, "gamma");
61+
GLMParameters params = createCommonParams(Family.gamma, data[0]);
62+
checkMojoRestrictedAndUnrestricted(params, data[1], "gamma");
63+
} finally {
64+
Scope.exit();
65+
}
66+
}
67+
68+
@Test
69+
public void testTweediePredMojoControlVariables() {
70+
try {
71+
Scope.enter();
72+
Frame[] data = createAndSplitData(1, true, false, "tweedie");
73+
GLMParameters params = createCommonParams(Family.tweedie, data[0]);
74+
params._tweedie_variance_power = 1.5;
75+
params._tweedie_link_power = 1 - 1.5;
76+
checkMojoRestrictedAndUnrestricted(params, data[1], "tweedie");
77+
} finally {
78+
Scope.exit();
79+
}
80+
}
81+
82+
@Test
83+
public void testNegativeBinomialPredMojoControlVariables() {
84+
try {
85+
Scope.enter();
86+
Frame[] data = createAndSplitData(20, false, true, "negativebinomial");
87+
GLMParameters params = createCommonParams(Family.negativebinomial, data[0]);
88+
params._theta = 0.5;
89+
checkMojoRestrictedAndUnrestricted(params, data[1], "negativebinomial");
90+
} finally {
91+
Scope.exit();
92+
}
93+
}
94+
95+
@Test
96+
public void testFractionalBinomialPredMojoControlVariables() {
97+
try {
98+
Scope.enter();
99+
Frame[] data = createAndSplitData(2, true, true, "fractionalbinomial");
100+
GLMParameters params = createCommonParams(Family.fractionalbinomial, data[0]);
101+
checkMojoRestrictedAndUnrestricted(params, data[1], "fractionalbinomial");
102+
} finally {
103+
Scope.exit();
104+
}
105+
}
106+
107+
@Test
108+
public void testQuasibinomialPredMojoControlVariables() {
109+
try {
110+
Scope.enter();
111+
Frame[] data = createAndSplitData(2, true, true, "quasibinomial");
112+
GLMParameters params = createCommonParams(Family.quasibinomial, data[0]);
113+
checkMojoRestrictedAndUnrestricted(params, data[1], "quasibinomial");
114+
} finally {
115+
Scope.exit();
116+
}
117+
}
118+
119+
/**
120+
* Creates a synthetic dataset and splits it 80/20 into train/test frames.
121+
*/
122+
private Frame[] createAndSplitData(int responseFactors, boolean positiveResponse,
123+
boolean convertResponseToNumeric, String suffix) {
124+
CreateFrame cf = new CreateFrame();
125+
Random generator = new Random();
126+
int numRows = generator.nextInt(10000) + 15000 + 200;
127+
int numCols = generator.nextInt(17) + 3;
128+
cf.rows = numRows;
129+
cf.cols = numCols;
130+
cf.factors = 10;
131+
cf.has_response = true;
132+
cf.response_factors = responseFactors;
133+
cf.positive_response = positiveResponse;
134+
cf.missing_fraction = 0;
135+
cf.seed = System.currentTimeMillis();
136+
System.out.println("Createframe parameters: rows: " + numRows + " cols:" + numCols +
137+
" seed: " + cf.seed + " family: " + suffix);
138+
139+
Frame trainData = cf.execImpl().get();
140+
if (convertResponseToNumeric) {
141+
Vec v = trainData.remove("response");
142+
trainData.add("response", v.toNumericVec());
143+
Scope.track(v);
144+
DKV.put(trainData);
145+
}
146+
Scope.track(trainData);
147+
148+
SplitFrame sf = new SplitFrame(trainData, new double[]{0.8, 0.2},
149+
new Key[]{Key.make("train_" + suffix + ".hex"), Key.make("test_" + suffix + ".hex")});
150+
sf.exec().get();
151+
Key[] ksplits = sf._destination_frames;
152+
Frame tr = DKV.get(ksplits[0]).get();
153+
Frame te = DKV.get(ksplits[1]).get();
154+
Scope.track(tr);
155+
Scope.track(te);
156+
return new Frame[]{tr, te};
157+
}
158+
159+
private GLMParameters createCommonParams(Family family, Frame tr) {
160+
GLMParameters params = new GLMParameters(family, family.defaultLink,
161+
new double[]{0}, new double[]{0}, 0, 0);
162+
params._train = tr._key;
163+
params._lambda_search = false;
164+
params._response_column = "response";
165+
params._lambda = new double[]{0};
166+
params._alpha = new double[]{0.001};
167+
params._objective_epsilon = 1e-6;
168+
params._beta_epsilon = 1e-4;
169+
params._standardize = false;
170+
params._control_variables = new String[]{"C1", "C2"};
171+
return params;
172+
}
173+
174+
/**
175+
* Tests both the restricted (control variables active) and unrestricted models:
176+
* 1. Restricted model: MOJO/POJO predictions match H2O predictions
177+
* 2. Unrestricted model: MOJO/POJO predictions match H2O predictions
178+
* 3. Restricted vs unrestricted predictions differ
179+
*/
180+
private void checkMojoRestrictedAndUnrestricted(GLMParameters params, Frame te, String suffix) {
181+
// Train and check restricted model (control variables zeroed during scoring)
182+
GLMModel restrictedModel = new GLM(params).trainModel().get();
183+
Scope.track_generic(restrictedModel);
184+
Frame predRestricted = restrictedModel.score(te);
185+
Scope.track(predRestricted);
186+
Assert.assertTrue(restrictedModel.haveMojo());
187+
Assert.assertTrue(restrictedModel.testJavaScoring(te, predRestricted, _tol));
188+
189+
// Create and check unrestricted model (control variable betas included in scoring)
190+
GLMModel unrestrictedModel = createUnrestrictedModel(restrictedModel, suffix);
191+
Scope.track_generic(unrestrictedModel);
192+
Frame predUnrestricted = unrestrictedModel.score(te);
193+
Scope.track(predUnrestricted);
194+
Assert.assertTrue(unrestrictedModel.haveMojo());
195+
Assert.assertTrue(unrestrictedModel.testJavaScoring(te, predUnrestricted, _tol));
196+
197+
// Restricted and unrestricted predictions must differ
198+
assertPredictionsDiffer(predRestricted, predUnrestricted);
199+
}
200+
201+
private GLMModel createUnrestrictedModel(GLMModel model, String suffix) {
202+
MakeGLMModelHandler handler = new MakeGLMModelHandler();
203+
MakeUnrestrictedGLMModelV3 args = new MakeUnrestrictedGLMModelV3();
204+
args.model = new KeyV3.ModelKeyV3(model._key);
205+
args.dest = "unrestricted_" + suffix;
206+
handler.make_unrestricted_model(3, args);
207+
return DKV.getGet(Key.make("unrestricted_" + suffix));
208+
}
209+
210+
/**
211+
* Asserts that at least 10% of prediction values differ between two prediction frames.
212+
* Uses the probability column for classification (vec index 1) and the predict column
213+
* for regression (vec index 0).
214+
*/
215+
private void assertPredictionsDiffer(Frame pred1, Frame pred2) {
216+
int colIdx = pred1.numCols() > 1 ? 1 : 0;
217+
long nrows = Math.min(pred1.numRows(), 100);
218+
int differ = 0;
219+
for (int i = 0; i < nrows; i++) {
220+
if (Math.abs(pred1.vec(colIdx).at(i) - pred2.vec(colIdx).at(i)) > 1e-10) differ++;
221+
}
222+
Assert.assertTrue("Restricted and unrestricted predictions should differ (only " +
223+
differ + "/" + nrows + " rows differed)", differ > nrows / 10);
224+
}
225+
}
Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
import sys, tempfile
2+
sys.path.insert(1, "../../../")
3+
import h2o
4+
from tests import pyunit_utils
5+
from h2o.estimators.glm import H2OGeneralizedLinearEstimator
6+
7+
FAMILIES = [
8+
dict(family="gaussian", data="regression"),
9+
dict(family="binomial", data="binomial"),
10+
dict(family="tweedie", data="regression", tweedie_variance_power=1.5, tweedie_link_power=-0.5),
11+
]
12+
13+
14+
def make_data(dtype):
15+
if dtype == "regression":
16+
train = h2o.create_frame(rows=5000, cols=10, factors=10, has_response=True,
17+
response_factors=1, positive_response=True, missing_fraction=0, seed=1234)
18+
elif dtype == "binomial":
19+
train = h2o.create_frame(rows=5000, cols=10, factors=10, has_response=True,
20+
response_factors=2, missing_fraction=0, seed=1234)
21+
return train.split_frame(ratios=[0.8], seed=1234)
22+
23+
24+
def check_mojo(model, test, pred_h2o):
25+
mojo_path = model.download_mojo(path=tempfile.mkdtemp())
26+
mojo_model = h2o.upload_mojo(mojo_path)
27+
pred_mojo = mojo_model.predict(test.drop("response"))
28+
common_cols = [c for c in pred_h2o.columns if c in pred_mojo.columns]
29+
prob_cols = [c for c in common_cols if c in ("p0", "p1")]
30+
cols = prob_cols if prob_cols else common_cols
31+
pyunit_utils.compare_frames_local(pred_h2o[cols], pred_mojo[cols], prob=1, tol=1e-6)
32+
33+
34+
def run_family_test(spec, standardize):
35+
print("\n--- %s, standardize=%s ---" % (spec["family"], standardize))
36+
train, test = make_data(spec["data"])
37+
extra = {k: v for k, v in spec.items() if k != "data"}
38+
39+
model = H2OGeneralizedLinearEstimator(
40+
lambda_=0, alpha=0.001,
41+
standardize=standardize, control_variables=["C1", "C2"],
42+
**extra
43+
)
44+
model.train(x=[c for c in train.columns if c != "response"], y="response", training_frame=train)
45+
46+
pred_restricted = model.predict(test)
47+
unrestricted = model.make_unrestricted_glm_model()
48+
pred_unrestricted = unrestricted.predict(test)
49+
50+
col = "p0" if "p0" in pred_restricted.columns else "predict"
51+
max_diff = (pred_restricted[col].asnumeric() - pred_unrestricted[col].asnumeric()).abs().max()
52+
assert max_diff > 1e-10, \
53+
"%s (standardize=%s): restricted vs unrestricted max diff = %e" % (spec["family"], standardize, max_diff)
54+
55+
check_mojo(model, test, pred_restricted)
56+
check_mojo(unrestricted, test, pred_unrestricted)
57+
58+
59+
def glm_mojo_control_variables():
60+
for spec in FAMILIES:
61+
for standardize in [False, True]:
62+
run_family_test(spec, standardize)
63+
64+
65+
if __name__ == "__main__":
66+
pyunit_utils.standalone_test(glm_mojo_control_variables)
67+
else:
68+
glm_mojo_control_variables()

0 commit comments

Comments
 (0)