Skip to content

Commit 26cd3ab

Browse files
Mathanraj-Sharmavalenad1
authored andcommitted
Add catOnly test for coxPH
1 parent 28ce8b9 commit 26cd3ab

5 files changed

Lines changed: 119 additions & 2 deletions

File tree

h2o-genmodel/src/test/java/hex/genmodel/algos/coxph/CoxPHMojoModelTest.java

Lines changed: 119 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,11 +39,15 @@ public void testForOneCategory() {
3939
assertEquals(0.6, mojo.forOneCategory(row, 1, 0), 0);
4040
}
4141

42+
/*
43+
Test backward compatibility of CoxPH mojo model trained in 3.32.1.3.
44+
Only using all features (both categorical and numeric).
45+
*/
4246
@Test
43-
public void testCoxPHBackwardCompatibility() throws Exception {
47+
public void testCoxPHBackwardCompatibilityAll332() throws Exception {
4448
String mojofile = String.valueOf(
4549
Paths.get(
46-
CoxPHMojoModelTest.class.getClassLoader().getResource("hex/genmodel/algos/coxph/CoxPH_bc.zip").toURI()
50+
CoxPHMojoModelTest.class.getClassLoader().getResource("hex/genmodel/algos/coxph/CoxPH_bc_all_3_32.zip").toURI()
4751
).toFile()
4852
);
4953

@@ -81,4 +85,117 @@ public void testCoxPHBackwardCompatibility() throws Exception {
8185
assertArrayEquals(expected, preds, 0.000001);
8286
}
8387

88+
/*
89+
Test backward compatibility of CoxPH mojo model trained in 3.42.0.4.
90+
Only using all features (both categorical and numeric).
91+
*/
92+
@Test
93+
public void testCoxPHBackwardCompatibilityAll342() throws Exception {
94+
String mojofile = String.valueOf(
95+
Paths.get(
96+
CoxPHMojoModelTest.class.getClassLoader().getResource("hex/genmodel/algos/coxph/CoxPH_bc_all_3_42.zip").toURI()
97+
).toFile()
98+
);
99+
100+
EasyPredictModelWrapper.Config config = new EasyPredictModelWrapper.Config()
101+
.setModel(MojoModel.load(mojofile));
102+
EasyPredictModelWrapper model = new EasyPredictModelWrapper(config);
103+
104+
String [][] inputs = new String[][] {
105+
{"0", "50", "1", "-17.1553730321697", "0.123203285420945","0",
106+
"0", "1", "-2.3004572363122033", "-23.840464872142775", "c0.l2", "c1.l2"},
107+
{"0", "6", "1", "3.83572895277207", "0.254620123203285", "0",
108+
"0", "2", "33.01605679101838", "-12.944002705270874" ,"c0.l1","c1.l2"},
109+
{"0", "1", "0", "6.29705681040383", "0.265571526351814", "0",
110+
"0", "3", "-22.54601829241052", "77.61631885563669", "c0.l2", "c1.l0"}
111+
};
112+
double [] expected = {0.0902431, 0.422815, 0.663896};
113+
double[] preds = new double[inputs.length];
114+
for (int i = 0; i < inputs.length; i++) {
115+
RowData row = new RowData();
116+
row.put("start", inputs[i][0]);
117+
row.put("stop", inputs[i][1]);
118+
row.put("event", inputs[i][2]);
119+
row.put("age", inputs[i][3]);
120+
row.put("year", inputs[i][4]);
121+
row.put("surgery", inputs[i][5]);
122+
row.put("transplant", inputs[i][6]);
123+
row.put("id", inputs[i][7]);
124+
row.put("C1", inputs[i][8]);
125+
row.put("C2", inputs[i][9]);
126+
row.put("C3", inputs[i][10]);
127+
row.put("C4", inputs[i][11]);
128+
preds[i] = model.predictCoxPH(row).value;
129+
}
130+
131+
assertArrayEquals(expected, preds, 0.000001);
132+
}
133+
134+
/*
135+
Test backward compatibility of CoxPH mojo model trained in 3.32.1.3.
136+
Only using categorical features.
137+
*/
138+
@Test
139+
public void testCoxPHBackwardCompatibilityCatOnly332() throws Exception {
140+
String mojofile = String.valueOf(
141+
Paths.get(
142+
CoxPHMojoModelTest.class.getClassLoader().getResource("hex/genmodel/algos/coxph/CoxPH_bc_catOnly_3_32.zip").toURI()
143+
).toFile()
144+
);
145+
146+
EasyPredictModelWrapper.Config config = new EasyPredictModelWrapper.Config()
147+
.setModel(MojoModel.load(mojofile));
148+
EasyPredictModelWrapper model = new EasyPredictModelWrapper(config);
149+
150+
String [][] inputs = new String[][] {
151+
{"0", "0"},
152+
{"0", "1"},
153+
{"1", "0"},
154+
{"1", "1"},
155+
};
156+
double [] expected = {0.0628001, 0.221134, -0.686394, -0.528061};
157+
double[] preds = new double[inputs.length];
158+
for (int i = 0; i < inputs.length; i++) {
159+
RowData row = new RowData();
160+
row.put("surgery", inputs[i][0]);
161+
row.put("transplant", inputs[i][1]);
162+
preds[i] = model.predictCoxPH(row).value;
163+
}
164+
165+
assertArrayEquals(expected, preds, 0.000001);
166+
}
167+
168+
/*
169+
Test backward compatibility of CoxPH mojo model trained in 3.42.0.4.
170+
Only using categorical features.
171+
*/
172+
@Test
173+
public void testCoxPHBackwardCompatibilityCatOnly342() throws Exception {
174+
String mojofile = String.valueOf(
175+
Paths.get(
176+
CoxPHMojoModelTest.class.getClassLoader().getResource("hex/genmodel/algos/coxph/CoxPH_bc_catOnly_3_42.zip").toURI()
177+
).toFile()
178+
);
179+
180+
EasyPredictModelWrapper.Config config = new EasyPredictModelWrapper.Config()
181+
.setModel(MojoModel.load(mojofile));
182+
EasyPredictModelWrapper model = new EasyPredictModelWrapper(config);
183+
184+
String [][] inputs = new String[][] {
185+
{"0", "0"},
186+
{"0", "1"},
187+
{"1", "0"},
188+
{"1", "1"},
189+
};
190+
double [] expected = {0.0628001, 0.221134, -0.686394, -0.528061};
191+
double[] preds = new double[inputs.length];
192+
for (int i = 0; i < inputs.length; i++) {
193+
RowData row = new RowData();
194+
row.put("surgery", inputs[i][0]);
195+
row.put("transplant", inputs[i][1]);
196+
preds[i] = model.predictCoxPH(row).value;
197+
}
198+
199+
assertArrayEquals(expected, preds, 0.000001);
200+
}
84201
}

h2o-genmodel/src/test/resources/hex/genmodel/algos/coxph/CoxPH_bc.zip renamed to h2o-genmodel/src/test/resources/hex/genmodel/algos/coxph/CoxPH_bc_all_3_32.zip

File renamed without changes.
Binary file not shown.
Binary file not shown.
Binary file not shown.

0 commit comments

Comments
 (0)