|
| 1 | +package hex.glm; |
| 2 | + |
| 3 | +import hex.CreateFrame; |
| 4 | +import hex.SplitFrame; |
| 5 | +import hex.api.MakeGLMModelHandler; |
| 6 | +import hex.glm.GLMModel.GLMParameters; |
| 7 | +import hex.glm.GLMModel.GLMParameters.Family; |
| 8 | +import hex.schemas.MakeUnrestrictedGLMModelV3; |
| 9 | +import org.junit.Assert; |
| 10 | +import org.junit.Test; |
| 11 | +import org.junit.runner.RunWith; |
| 12 | +import water.*; |
| 13 | +import water.api.schemas3.KeyV3; |
| 14 | +import water.fvec.Frame; |
| 15 | +import water.fvec.Vec; |
| 16 | +import water.runner.CloudSize; |
| 17 | +import water.runner.H2ORunner; |
| 18 | + |
| 19 | +import java.util.Random; |
| 20 | + |
| 21 | +/** |
| 22 | + * Tests MOJO/POJO scoring for GLM models with control variables across all supported families. |
| 23 | + * For each family, verifies that both restricted and unrestricted models produce correct |
| 24 | + * MOJO/POJO predictions and that their predictions differ from each other. |
| 25 | + */ |
| 26 | +@RunWith(H2ORunner.class) |
| 27 | +@CloudSize(1) |
| 28 | +public class GLMPredMojoControlVariablesTest extends TestUtil { |
| 29 | + |
| 30 | + double _tol = 1e-6; |
| 31 | + |
| 32 | + @Test |
| 33 | + public void testGaussianPredMojoControlVariables() { |
| 34 | + try { |
| 35 | + Scope.enter(); |
| 36 | + Frame[] data = createAndSplitData(1, true, false, "gaussian"); |
| 37 | + GLMParameters params = createCommonParams(Family.gaussian, data[0]); |
| 38 | + checkMojoRestrictedAndUnrestricted(params, data[1], "gaussian"); |
| 39 | + } finally { |
| 40 | + Scope.exit(); |
| 41 | + } |
| 42 | + } |
| 43 | + |
| 44 | + @Test |
| 45 | + public void testPoissonPredMojoControlVariables() { |
| 46 | + try { |
| 47 | + Scope.enter(); |
| 48 | + Frame[] data = createAndSplitData(20, false, true, "poisson"); |
| 49 | + GLMParameters params = createCommonParams(Family.poisson, data[0]); |
| 50 | + checkMojoRestrictedAndUnrestricted(params, data[1], "poisson"); |
| 51 | + } finally { |
| 52 | + Scope.exit(); |
| 53 | + } |
| 54 | + } |
| 55 | + |
| 56 | + @Test |
| 57 | + public void testGammaPredMojoControlVariables() { |
| 58 | + try { |
| 59 | + Scope.enter(); |
| 60 | + Frame[] data = createAndSplitData(1, true, false, "gamma"); |
| 61 | + GLMParameters params = createCommonParams(Family.gamma, data[0]); |
| 62 | + checkMojoRestrictedAndUnrestricted(params, data[1], "gamma"); |
| 63 | + } finally { |
| 64 | + Scope.exit(); |
| 65 | + } |
| 66 | + } |
| 67 | + |
| 68 | + @Test |
| 69 | + public void testTweediePredMojoControlVariables() { |
| 70 | + try { |
| 71 | + Scope.enter(); |
| 72 | + Frame[] data = createAndSplitData(1, true, false, "tweedie"); |
| 73 | + GLMParameters params = createCommonParams(Family.tweedie, data[0]); |
| 74 | + params._tweedie_variance_power = 1.5; |
| 75 | + params._tweedie_link_power = 1 - 1.5; |
| 76 | + checkMojoRestrictedAndUnrestricted(params, data[1], "tweedie"); |
| 77 | + } finally { |
| 78 | + Scope.exit(); |
| 79 | + } |
| 80 | + } |
| 81 | + |
| 82 | + @Test |
| 83 | + public void testNegativeBinomialPredMojoControlVariables() { |
| 84 | + try { |
| 85 | + Scope.enter(); |
| 86 | + Frame[] data = createAndSplitData(20, false, true, "negativebinomial"); |
| 87 | + GLMParameters params = createCommonParams(Family.negativebinomial, data[0]); |
| 88 | + params._theta = 0.5; |
| 89 | + checkMojoRestrictedAndUnrestricted(params, data[1], "negativebinomial"); |
| 90 | + } finally { |
| 91 | + Scope.exit(); |
| 92 | + } |
| 93 | + } |
| 94 | + |
| 95 | + @Test |
| 96 | + public void testFractionalBinomialPredMojoControlVariables() { |
| 97 | + try { |
| 98 | + Scope.enter(); |
| 99 | + Frame[] data = createAndSplitData(2, true, true, "fractionalbinomial"); |
| 100 | + GLMParameters params = createCommonParams(Family.fractionalbinomial, data[0]); |
| 101 | + checkMojoRestrictedAndUnrestricted(params, data[1], "fractionalbinomial"); |
| 102 | + } finally { |
| 103 | + Scope.exit(); |
| 104 | + } |
| 105 | + } |
| 106 | + |
| 107 | + @Test |
| 108 | + public void testQuasibinomialPredMojoControlVariables() { |
| 109 | + try { |
| 110 | + Scope.enter(); |
| 111 | + Frame[] data = createAndSplitData(2, true, true, "quasibinomial"); |
| 112 | + GLMParameters params = createCommonParams(Family.quasibinomial, data[0]); |
| 113 | + checkMojoRestrictedAndUnrestricted(params, data[1], "quasibinomial"); |
| 114 | + } finally { |
| 115 | + Scope.exit(); |
| 116 | + } |
| 117 | + } |
| 118 | + |
| 119 | + /** |
| 120 | + * Creates a synthetic dataset and splits it 80/20 into train/test frames. |
| 121 | + */ |
| 122 | + private Frame[] createAndSplitData(int responseFactors, boolean positiveResponse, |
| 123 | + boolean convertResponseToNumeric, String suffix) { |
| 124 | + CreateFrame cf = new CreateFrame(); |
| 125 | + Random generator = new Random(); |
| 126 | + int numRows = generator.nextInt(10000) + 15000 + 200; |
| 127 | + int numCols = generator.nextInt(17) + 3; |
| 128 | + cf.rows = numRows; |
| 129 | + cf.cols = numCols; |
| 130 | + cf.factors = 10; |
| 131 | + cf.has_response = true; |
| 132 | + cf.response_factors = responseFactors; |
| 133 | + cf.positive_response = positiveResponse; |
| 134 | + cf.missing_fraction = 0; |
| 135 | + cf.seed = System.currentTimeMillis(); |
| 136 | + System.out.println("Createframe parameters: rows: " + numRows + " cols:" + numCols + |
| 137 | + " seed: " + cf.seed + " family: " + suffix); |
| 138 | + |
| 139 | + Frame trainData = cf.execImpl().get(); |
| 140 | + if (convertResponseToNumeric) { |
| 141 | + Vec v = trainData.remove("response"); |
| 142 | + trainData.add("response", v.toNumericVec()); |
| 143 | + Scope.track(v); |
| 144 | + DKV.put(trainData); |
| 145 | + } |
| 146 | + Scope.track(trainData); |
| 147 | + |
| 148 | + SplitFrame sf = new SplitFrame(trainData, new double[]{0.8, 0.2}, |
| 149 | + new Key[]{Key.make("train_" + suffix + ".hex"), Key.make("test_" + suffix + ".hex")}); |
| 150 | + sf.exec().get(); |
| 151 | + Key[] ksplits = sf._destination_frames; |
| 152 | + Frame tr = DKV.get(ksplits[0]).get(); |
| 153 | + Frame te = DKV.get(ksplits[1]).get(); |
| 154 | + Scope.track(tr); |
| 155 | + Scope.track(te); |
| 156 | + return new Frame[]{tr, te}; |
| 157 | + } |
| 158 | + |
| 159 | + private GLMParameters createCommonParams(Family family, Frame tr) { |
| 160 | + GLMParameters params = new GLMParameters(family, family.defaultLink, |
| 161 | + new double[]{0}, new double[]{0}, 0, 0); |
| 162 | + params._train = tr._key; |
| 163 | + params._lambda_search = false; |
| 164 | + params._response_column = "response"; |
| 165 | + params._lambda = new double[]{0}; |
| 166 | + params._alpha = new double[]{0.001}; |
| 167 | + params._objective_epsilon = 1e-6; |
| 168 | + params._beta_epsilon = 1e-4; |
| 169 | + params._standardize = false; |
| 170 | + params._control_variables = new String[]{"C1", "C2"}; |
| 171 | + return params; |
| 172 | + } |
| 173 | + |
| 174 | + /** |
| 175 | + * Tests both the restricted (control variables active) and unrestricted models: |
| 176 | + * 1. Restricted model: MOJO/POJO predictions match H2O predictions |
| 177 | + * 2. Unrestricted model: MOJO/POJO predictions match H2O predictions |
| 178 | + * 3. Restricted vs unrestricted predictions differ |
| 179 | + */ |
| 180 | + private void checkMojoRestrictedAndUnrestricted(GLMParameters params, Frame te, String suffix) { |
| 181 | + // Train and check restricted model (control variables zeroed during scoring) |
| 182 | + GLMModel restrictedModel = new GLM(params).trainModel().get(); |
| 183 | + Scope.track_generic(restrictedModel); |
| 184 | + Frame predRestricted = restrictedModel.score(te); |
| 185 | + Scope.track(predRestricted); |
| 186 | + Assert.assertTrue(restrictedModel.haveMojo()); |
| 187 | + Assert.assertTrue(restrictedModel.testJavaScoring(te, predRestricted, _tol)); |
| 188 | + |
| 189 | + // Create and check unrestricted model (control variable betas included in scoring) |
| 190 | + GLMModel unrestrictedModel = createUnrestrictedModel(restrictedModel, suffix); |
| 191 | + Scope.track_generic(unrestrictedModel); |
| 192 | + Frame predUnrestricted = unrestrictedModel.score(te); |
| 193 | + Scope.track(predUnrestricted); |
| 194 | + Assert.assertTrue(unrestrictedModel.haveMojo()); |
| 195 | + Assert.assertTrue(unrestrictedModel.testJavaScoring(te, predUnrestricted, _tol)); |
| 196 | + |
| 197 | + // Restricted and unrestricted predictions must differ |
| 198 | + assertPredictionsDiffer(predRestricted, predUnrestricted); |
| 199 | + } |
| 200 | + |
| 201 | + private GLMModel createUnrestrictedModel(GLMModel model, String suffix) { |
| 202 | + MakeGLMModelHandler handler = new MakeGLMModelHandler(); |
| 203 | + MakeUnrestrictedGLMModelV3 args = new MakeUnrestrictedGLMModelV3(); |
| 204 | + args.model = new KeyV3.ModelKeyV3(model._key); |
| 205 | + args.dest = "unrestricted_" + suffix; |
| 206 | + handler.make_unrestricted_model(3, args); |
| 207 | + return DKV.getGet(Key.make("unrestricted_" + suffix)); |
| 208 | + } |
| 209 | + |
| 210 | + /** |
| 211 | + * Asserts that at least 10% of prediction values differ between two prediction frames. |
| 212 | + * Uses the probability column for classification (vec index 1) and the predict column |
| 213 | + * for regression (vec index 0). |
| 214 | + */ |
| 215 | + private void assertPredictionsDiffer(Frame pred1, Frame pred2) { |
| 216 | + int colIdx = pred1.numCols() > 1 ? 1 : 0; |
| 217 | + long nrows = Math.min(pred1.numRows(), 100); |
| 218 | + int differ = 0; |
| 219 | + for (int i = 0; i < nrows; i++) { |
| 220 | + if (Math.abs(pred1.vec(colIdx).at(i) - pred2.vec(colIdx).at(i)) > 1e-10) differ++; |
| 221 | + } |
| 222 | + Assert.assertTrue("Restricted and unrestricted predictions should differ (only " + |
| 223 | + differ + "/" + nrows + " rows differed)", differ > nrows / 10); |
| 224 | + } |
| 225 | +} |
0 commit comments