diff --git a/h2o-genmodel/src/main/java/hex/genmodel/algos/coxph/CoxPHMojoModel.java b/h2o-genmodel/src/main/java/hex/genmodel/algos/coxph/CoxPHMojoModel.java index 3f2f6b30525f..9d402f8ae690 100644 --- a/h2o-genmodel/src/main/java/hex/genmodel/algos/coxph/CoxPHMojoModel.java +++ b/h2o-genmodel/src/main/java/hex/genmodel/algos/coxph/CoxPHMojoModel.java @@ -78,13 +78,27 @@ public double[] score0(double[] row, double[] predictions) { @Override public double[] score0(double[] row, double offset, double[] predictions) { int[] enumOffset = null; - if (_interaction_targets != null) { - enumOffset = evaluateInteractions(row); + + if (_nums == -1) { + predictions[0] = forCategories(row) + forOtherColumns(row) - forStrata(row) + offset; + } else { + if (_interaction_targets != null) { + enumOffset = evaluateInteractions(row); + } + predictions[0] = forCategories(row) + forOtherColumns(row, enumOffset) - forStrata(row) + offset; } - predictions[0] = forCategories(row) + forOtherColumns(row, enumOffset) - forStrata(row) + offset; return predictions; } + private double forOtherColumns(double[] row) { + double result = 0.0; + int catOffsetDiff = _cat_offsets[_cats] - _cats; + for(int i = _cats ; i + catOffsetDiff < _coef.length; i++) { + result += _coef[catOffsetDiff + i] * featureValue(row, i); + } + return result; + } + private double forOtherColumns(double[] row, int[] enumOffset) { double result = 0.0; int coefLen = _coef.length; diff --git a/h2o-genmodel/src/main/java/hex/genmodel/algos/coxph/CoxPHMojoReader.java b/h2o-genmodel/src/main/java/hex/genmodel/algos/coxph/CoxPHMojoReader.java index 3354a240f82a..273a230c68e5 100644 --- a/h2o-genmodel/src/main/java/hex/genmodel/algos/coxph/CoxPHMojoReader.java +++ b/h2o-genmodel/src/main/java/hex/genmodel/algos/coxph/CoxPHMojoReader.java @@ -28,7 +28,7 @@ protected void readModelData() throws IOException { _model._interaction_targets = readkv("interaction_targets"); _model._interaction_column_index = new HashSet<>(); _model._interaction_column_domains = new HashMap<>(); - _model._nums = readkv("num_numerical_columns"); + _model._nums = readkv("num_numerical_columns", -1); _model._num_offsets = readkv("num_offsets"); if (_model._interaction_targets != null) { diff --git a/h2o-genmodel/src/test/java/hex/genmodel/algos/coxph/CoxPHMojoModelTest.java b/h2o-genmodel/src/test/java/hex/genmodel/algos/coxph/CoxPHMojoModelTest.java index 9ce81947836a..31967125a550 100644 --- a/h2o-genmodel/src/test/java/hex/genmodel/algos/coxph/CoxPHMojoModelTest.java +++ b/h2o-genmodel/src/test/java/hex/genmodel/algos/coxph/CoxPHMojoModelTest.java @@ -1,6 +1,12 @@ package hex.genmodel.algos.coxph; import org.junit.Test; +import java.nio.file.Paths; + +import hex.genmodel.MojoModel; +import hex.genmodel.easy.RowData; +import hex.genmodel.easy.EasyPredictModelWrapper; + import static org.junit.Assert.*; public class CoxPHMojoModelTest { @@ -33,4 +39,163 @@ public void testForOneCategory() { assertEquals(0.6, mojo.forOneCategory(row, 1, 0), 0); } + /* + Test backward compatibility of CoxPH mojo model trained in 3.32.1.3. + Only using all features (both categorical and numeric). + */ + @Test + public void testCoxPHBackwardCompatibilityAll332() throws Exception { + String mojofile = String.valueOf( + Paths.get( + CoxPHMojoModelTest.class.getClassLoader().getResource("hex/genmodel/algos/coxph/CoxPH_bc_all_3_32.zip").toURI() + ).toFile() + ); + + EasyPredictModelWrapper.Config config = new EasyPredictModelWrapper.Config() + .setModel(MojoModel.load(mojofile)); + EasyPredictModelWrapper model = new EasyPredictModelWrapper(config); + + String [][] inputs = new String[][] { + {"0", "50", "1", "-17.1553730321697", "0.123203285420945","0", + "0", "1", "-2.3004572363122033", "-23.840464872142775", "c0.l2", "c1.l2"}, + {"0", "6", "1", "3.83572895277207", "0.254620123203285", "0", + "0", "2", "33.01605679101838", "-12.944002705270874" ,"c0.l1","c1.l2"}, + {"0", "1", "0", "6.29705681040383", "0.265571526351814", "0", + "0", "3", "-22.54601829241052", "77.61631885563669", "c0.l2", "c1.l0"} + }; + double [] expected = {0.0902431, 0.422815, 0.663896}; + double[] preds = new double[inputs.length]; + for (int i = 0; i < inputs.length; i++) { + RowData row = new RowData(); + row.put("start", inputs[i][0]); + row.put("stop", inputs[i][1]); + row.put("event", inputs[i][2]); + row.put("age", inputs[i][3]); + row.put("year", inputs[i][4]); + row.put("surgery", inputs[i][5]); + row.put("transplant", inputs[i][6]); + row.put("id", inputs[i][7]); + row.put("C1", inputs[i][8]); + row.put("C2", inputs[i][9]); + row.put("C3", inputs[i][10]); + row.put("C4", inputs[i][11]); + preds[i] = model.predictCoxPH(row).value; + } + + assertArrayEquals(expected, preds, 0.000001); + } + + /* + Test backward compatibility of CoxPH mojo model trained in 3.42.0.4. + Only using all features (both categorical and numeric). + */ + @Test + public void testCoxPHBackwardCompatibilityAll342() throws Exception { + String mojofile = String.valueOf( + Paths.get( + CoxPHMojoModelTest.class.getClassLoader().getResource("hex/genmodel/algos/coxph/CoxPH_bc_all_3_42.zip").toURI() + ).toFile() + ); + + EasyPredictModelWrapper.Config config = new EasyPredictModelWrapper.Config() + .setModel(MojoModel.load(mojofile)); + EasyPredictModelWrapper model = new EasyPredictModelWrapper(config); + + String [][] inputs = new String[][] { + {"0", "50", "1", "-17.1553730321697", "0.123203285420945","0", + "0", "1", "-2.3004572363122033", "-23.840464872142775", "c0.l2", "c1.l2"}, + {"0", "6", "1", "3.83572895277207", "0.254620123203285", "0", + "0", "2", "33.01605679101838", "-12.944002705270874" ,"c0.l1","c1.l2"}, + {"0", "1", "0", "6.29705681040383", "0.265571526351814", "0", + "0", "3", "-22.54601829241052", "77.61631885563669", "c0.l2", "c1.l0"} + }; + double [] expected = {0.0902431, 0.422815, 0.663896}; + double[] preds = new double[inputs.length]; + for (int i = 0; i < inputs.length; i++) { + RowData row = new RowData(); + row.put("start", inputs[i][0]); + row.put("stop", inputs[i][1]); + row.put("event", inputs[i][2]); + row.put("age", inputs[i][3]); + row.put("year", inputs[i][4]); + row.put("surgery", inputs[i][5]); + row.put("transplant", inputs[i][6]); + row.put("id", inputs[i][7]); + row.put("C1", inputs[i][8]); + row.put("C2", inputs[i][9]); + row.put("C3", inputs[i][10]); + row.put("C4", inputs[i][11]); + preds[i] = model.predictCoxPH(row).value; + } + + assertArrayEquals(expected, preds, 0.000001); + } + + /* + Test backward compatibility of CoxPH mojo model trained in 3.32.1.3. + Only using categorical features. + */ + @Test + public void testCoxPHBackwardCompatibilityCatOnly332() throws Exception { + String mojofile = String.valueOf( + Paths.get( + CoxPHMojoModelTest.class.getClassLoader().getResource("hex/genmodel/algos/coxph/CoxPH_bc_catOnly_3_32.zip").toURI() + ).toFile() + ); + + EasyPredictModelWrapper.Config config = new EasyPredictModelWrapper.Config() + .setModel(MojoModel.load(mojofile)); + EasyPredictModelWrapper model = new EasyPredictModelWrapper(config); + + String [][] inputs = new String[][] { + {"0", "0"}, + {"0", "1"}, + {"1", "0"}, + {"1", "1"}, + }; + double [] expected = {0.0628001, 0.221134, -0.686394, -0.528061}; + double[] preds = new double[inputs.length]; + for (int i = 0; i < inputs.length; i++) { + RowData row = new RowData(); + row.put("surgery", inputs[i][0]); + row.put("transplant", inputs[i][1]); + preds[i] = model.predictCoxPH(row).value; + } + + assertArrayEquals(expected, preds, 0.000001); + } + + /* + Test backward compatibility of CoxPH mojo model trained in 3.42.0.4. + Only using categorical features. + */ + @Test + public void testCoxPHBackwardCompatibilityCatOnly342() throws Exception { + String mojofile = String.valueOf( + Paths.get( + CoxPHMojoModelTest.class.getClassLoader().getResource("hex/genmodel/algos/coxph/CoxPH_bc_catOnly_3_42.zip").toURI() + ).toFile() + ); + + EasyPredictModelWrapper.Config config = new EasyPredictModelWrapper.Config() + .setModel(MojoModel.load(mojofile)); + EasyPredictModelWrapper model = new EasyPredictModelWrapper(config); + + String [][] inputs = new String[][] { + {"0", "0"}, + {"0", "1"}, + {"1", "0"}, + {"1", "1"}, + }; + double [] expected = {0.0628001, 0.221134, -0.686394, -0.528061}; + double[] preds = new double[inputs.length]; + for (int i = 0; i < inputs.length; i++) { + RowData row = new RowData(); + row.put("surgery", inputs[i][0]); + row.put("transplant", inputs[i][1]); + preds[i] = model.predictCoxPH(row).value; + } + + assertArrayEquals(expected, preds, 0.000001); + } } diff --git a/h2o-genmodel/src/test/resources/hex/genmodel/algos/coxph/CoxPH_bc_all_3_32.zip b/h2o-genmodel/src/test/resources/hex/genmodel/algos/coxph/CoxPH_bc_all_3_32.zip new file mode 100644 index 000000000000..38cbe9224b06 Binary files /dev/null and b/h2o-genmodel/src/test/resources/hex/genmodel/algos/coxph/CoxPH_bc_all_3_32.zip differ diff --git a/h2o-genmodel/src/test/resources/hex/genmodel/algos/coxph/CoxPH_bc_all_3_42.zip b/h2o-genmodel/src/test/resources/hex/genmodel/algos/coxph/CoxPH_bc_all_3_42.zip new file mode 100644 index 000000000000..48ef8e01f137 Binary files /dev/null and b/h2o-genmodel/src/test/resources/hex/genmodel/algos/coxph/CoxPH_bc_all_3_42.zip differ diff --git a/h2o-genmodel/src/test/resources/hex/genmodel/algos/coxph/CoxPH_bc_catOnly_3_32.zip b/h2o-genmodel/src/test/resources/hex/genmodel/algos/coxph/CoxPH_bc_catOnly_3_32.zip new file mode 100644 index 000000000000..583f16e11700 Binary files /dev/null and b/h2o-genmodel/src/test/resources/hex/genmodel/algos/coxph/CoxPH_bc_catOnly_3_32.zip differ diff --git a/h2o-genmodel/src/test/resources/hex/genmodel/algos/coxph/CoxPH_bc_catOnly_3_42.zip b/h2o-genmodel/src/test/resources/hex/genmodel/algos/coxph/CoxPH_bc_catOnly_3_42.zip new file mode 100644 index 000000000000..4bc5ddbd3319 Binary files /dev/null and b/h2o-genmodel/src/test/resources/hex/genmodel/algos/coxph/CoxPH_bc_catOnly_3_42.zip differ