Skip to content

Commit ea67b29

Browse files
Fix NPE when scoring CoxPH mojo from h2o 3.32.x.x
1 parent 9aa1ad8 commit ea67b29

File tree

4 files changed

+66
-4
lines changed

4 files changed

+66
-4
lines changed

h2o-genmodel/src/main/java/hex/genmodel/algos/coxph/CoxPHMojoModel.java

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -78,13 +78,27 @@ public double[] score0(double[] row, double[] predictions) {
7878
@Override
7979
public double[] score0(double[] row, double offset, double[] predictions) {
8080
int[] enumOffset = null;
81-
if (_interaction_targets != null) {
82-
enumOffset = evaluateInteractions(row);
81+
82+
if (_nums == -1) {
83+
predictions[0] = forCategories(row) + forOtherColumns(row) - forStrata(row) + offset;
84+
} else {
85+
if (_interaction_targets != null) {
86+
enumOffset = evaluateInteractions(row);
87+
}
88+
predictions[0] = forCategories(row) + forOtherColumns(row, enumOffset) - forStrata(row) + offset;
8389
}
84-
predictions[0] = forCategories(row) + forOtherColumns(row, enumOffset) - forStrata(row) + offset;
8590
return predictions;
8691
}
8792

93+
private double forOtherColumns(double[] row) {
94+
double result = 0.0;
95+
int catOffsetDiff = _cat_offsets[_cats] - _cats;
96+
for(int i = _cats ; i + catOffsetDiff < _coef.length; i++) {
97+
result += _coef[catOffsetDiff + i] * featureValue(row, i);
98+
}
99+
return result;
100+
}
101+
88102
private double forOtherColumns(double[] row, int[] enumOffset) {
89103
double result = 0.0;
90104
int coefLen = _coef.length;

h2o-genmodel/src/main/java/hex/genmodel/algos/coxph/CoxPHMojoReader.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ protected void readModelData() throws IOException {
2828
_model._interaction_targets = readkv("interaction_targets");
2929
_model._interaction_column_index = new HashSet<>();
3030
_model._interaction_column_domains = new HashMap<>();
31-
_model._nums = readkv("num_numerical_columns");
31+
_model._nums = readkv("num_numerical_columns", -1);
3232
_model._num_offsets = readkv("num_offsets");
3333

3434
if (_model._interaction_targets != null) {

h2o-genmodel/src/test/java/hex/genmodel/algos/coxph/CoxPHMojoModelTest.java

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,12 @@
11
package hex.genmodel.algos.coxph;
22

33
import org.junit.Test;
4+
import java.nio.file.Paths;
5+
6+
import hex.genmodel.MojoModel;
7+
import hex.genmodel.easy.RowData;
8+
import hex.genmodel.easy.EasyPredictModelWrapper;
9+
410
import static org.junit.Assert.*;
511

612
public class CoxPHMojoModelTest {
@@ -33,4 +39,46 @@ public void testForOneCategory() {
3339
assertEquals(0.6, mojo.forOneCategory(row, 1, 0), 0);
3440
}
3541

42+
@Test
43+
public void testCoxPHBackwardCompatibility() throws Exception {
44+
String mojofile = String.valueOf(
45+
Paths.get(
46+
CoxPHMojoModelTest.class.getClassLoader().getResource("hex/genmodel/algos/coxph/CoxPH_bc.zip").toURI()
47+
).toFile()
48+
);
49+
50+
EasyPredictModelWrapper.Config config = new EasyPredictModelWrapper.Config()
51+
.setModel(MojoModel.load(mojofile));
52+
EasyPredictModelWrapper model = new EasyPredictModelWrapper(config);
53+
54+
String [][] inputs = new String[][] {
55+
{"0", "50", "1", "-17.1553730321697", "0.123203285420945","0",
56+
"0", "1", "-2.3004572363122033", "-23.840464872142775", "c0.l2", "c1.l2"},
57+
{"0", "6", "1", "3.83572895277207", "0.254620123203285", "0",
58+
"0", "2", "33.01605679101838", "-12.944002705270874" ,"c0.l1","c1.l2"},
59+
{"0", "1", "0", "6.29705681040383", "0.265571526351814", "0",
60+
"0", "3", "-22.54601829241052", "77.61631885563669", "c0.l2", "c1.l0"}
61+
};
62+
double [] expected = {0.0902431, 0.422815, 0.663896};
63+
double[] preds = new double[inputs.length];
64+
for (int i = 0; i < inputs.length; i++) {
65+
RowData row = new RowData();
66+
row.put("start", inputs[i][0]);
67+
row.put("stop", inputs[i][1]);
68+
row.put("event", inputs[i][2]);
69+
row.put("age", inputs[i][3]);
70+
row.put("year", inputs[i][4]);
71+
row.put("surgery", inputs[i][5]);
72+
row.put("transplant", inputs[i][6]);
73+
row.put("id", inputs[i][7]);
74+
row.put("C1", inputs[i][8]);
75+
row.put("C2", inputs[i][9]);
76+
row.put("C3", inputs[i][10]);
77+
row.put("C4", inputs[i][11]);
78+
preds[i] = model.predictCoxPH(row).value;
79+
}
80+
81+
assertArrayEquals(expected, preds, 0.000001);
82+
}
83+
3684
}
Binary file not shown.

0 commit comments

Comments
 (0)