Skip to content

Commit f9acd4b

Browse files
Merge pull request #16605 from h2oai/MR/master/issue-16590
GH-16590 Accept zero as offset for xgb models trained with offset
2 parents 3b93dea + 76bb62c commit f9acd4b

4 files changed

Lines changed: 175 additions & 71 deletions

File tree

h2o-algos/src/main/java/hex/generic/GenericModel.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -313,7 +313,7 @@ protected double[] score0(double[] data, double[] preds) {
313313

314314
@Override
315315
protected double[] score0(double[] data, double[] preds, double offset) {
316-
if (offset == 0) // MOJO doesn't like when score0 is called with 0 offset for problems that were trained without offset
316+
if (!_output.hasOffset() && offset == 0) // MOJO doesn't like when score0 is called with 0 offset for problems that were trained without offset
317317
return score0(data, preds);
318318
else
319319
return genModel().score0(data, offset, preds);

h2o-genmodel-extensions/xgboost/src/test/java/hex/genmodel/algos/xgboost/XGBoostJavaMojoModelTest.java

Lines changed: 174 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -7,93 +7,197 @@
77
import hex.genmodel.algos.tree.SharedTreeGraph;
88
import hex.genmodel.algos.tree.SharedTreeMojoModel;
99
import hex.genmodel.algos.tree.SharedTreeNode;
10-
import hex.genmodel.algos.tree.SharedTreeSubgraph;
1110
import hex.genmodel.easy.EasyPredictModelWrapper;
1211
import hex.genmodel.easy.RowData;
1312
import hex.genmodel.easy.prediction.RegressionModelPrediction;
1413
import org.junit.Test;
1514

1615
import java.io.*;
16+
import java.nio.file.Paths;
1717

1818
import static org.junit.Assert.*;
1919

2020
public class XGBoostJavaMojoModelTest {
2121

22-
@Test
23-
public void testObjFunction() { // make sure we have implementation for all supported obj functions
24-
for (XGBoostMojoModel.ObjectiveType type : XGBoostMojoModel.ObjectiveType.values()) {
25-
assertNotNull(type.getId());
26-
assertFalse(type.getId().isEmpty());
27-
// check we have an implementation of ObjFunction
28-
assertNotNull(XGBoostJavaMojoModel.getObjFunction(type.getId()));
22+
@Test
23+
public void testObjFunction() { // make sure we have implementation for all supported obj functions
24+
for (XGBoostMojoModel.ObjectiveType type : XGBoostMojoModel.ObjectiveType.values()) {
25+
assertNotNull(type.getId());
26+
assertFalse(type.getId().isEmpty());
27+
// check we have an implementation of ObjFunction
28+
assertNotNull(XGBoostJavaMojoModel.getObjFunction(type.getId()));
29+
}
2930
}
30-
}
31-
32-
@Test
33-
public void testPredictContributionsSerialization() throws Exception {
34-
MojoReaderBackend readerBackend = MojoReaderBackendFactory.createReaderBackend(
35-
XGBoostJavaMojoModelTest.class.getResource("xgboost_java.zip"),
36-
MojoReaderBackendFactory.CachingStrategy.MEMORY);
37-
XGBoostJavaMojoModel mojo = (XGBoostJavaMojoModel) MojoModel.load(readerBackend);
38-
PredictContributions pc = mojo.makeContributionsPredictor();
39-
assertNotNull(pc);
40-
assertTrue(deserialize(serialize(pc)) instanceof PredictContributions);
41-
}
42-
43-
@Test
44-
public void testLeafNodeAssignments() throws Exception {
45-
MojoReaderBackend readerBackend = MojoReaderBackendFactory.createReaderBackend(
46-
getClass().getResource("xgboost_java.zip"),
47-
MojoReaderBackendFactory.CachingStrategy.MEMORY);
48-
XGBoostJavaMojoModel mojo = (XGBoostJavaMojoModel) MojoModel.load(readerBackend);
49-
double[] doubles = new double[]{1, 2, 3, 4, 5, 6, 7};
50-
SharedTreeMojoModel.LeafNodeAssignments res = mojo.getLeafNodeAssignments(doubles);
51-
assertNotNull(res._nodeIds);
52-
assertNotNull(res._paths);
53-
String[] paths = mojo.getDecisionPath(doubles);
54-
assertArrayEquals(paths, res._paths);
55-
RowData data = new RowData();
56-
for (int i = 0; i< doubles.length; i++) data.put(mojo._names[i], doubles[i]);
57-
EasyPredictModelWrapper wrapper = new EasyPredictModelWrapper(
58-
new EasyPredictModelWrapper.Config().setModel(mojo).setEnableLeafAssignment(true)
59-
);
60-
RegressionModelPrediction res2 = (RegressionModelPrediction) wrapper.predict(data);
61-
assertNotNull(res2.leafNodeAssignmentIds);
62-
assertNotNull(res2.leafNodeAssignments);
63-
assertArrayEquals(res._nodeIds, res2.leafNodeAssignmentIds);
64-
assertArrayEquals(res._paths, res2.leafNodeAssignments);
65-
}
66-
67-
@Test
68-
public void testConvertWithWeights() throws Exception {
69-
MojoReaderBackend readerBackend = MojoReaderBackendFactory.createReaderBackend(
70-
XGBoostJavaMojoModelTest.class.getResource("xgboost_java.zip"),
71-
MojoReaderBackendFactory.CachingStrategy.MEMORY);
72-
XGBoostJavaMojoModel mojo = (XGBoostJavaMojoModel) MojoModel.load(readerBackend);
73-
SharedTreeGraph graph = mojo.convert(0, null);
74-
int expectedWeight = 380; // prostate dataset, 380 rows
75-
assertEquals(graph.subgraphArray.get(0).rootNode.getWeight(), expectedWeight, 0);
76-
double actualWeight = 0;
77-
for (SharedTreeNode node : graph.subgraphArray.get(0).nodesArray) {
78-
actualWeight += node.getWeight();
31+
32+
@Test
33+
public void testPredictContributionsSerialization() throws Exception {
34+
MojoReaderBackend readerBackend = MojoReaderBackendFactory.createReaderBackend(
35+
XGBoostJavaMojoModelTest.class.getResource("xgboost_java.zip"),
36+
MojoReaderBackendFactory.CachingStrategy.MEMORY);
37+
XGBoostJavaMojoModel mojo = (XGBoostJavaMojoModel) MojoModel.load(readerBackend);
38+
PredictContributions pc = mojo.makeContributionsPredictor();
39+
assertNotNull(pc);
40+
assertTrue(deserialize(serialize(pc)) instanceof PredictContributions);
41+
}
42+
43+
@Test
44+
public void testLeafNodeAssignments() throws Exception {
45+
MojoReaderBackend readerBackend = MojoReaderBackendFactory.createReaderBackend(
46+
getClass().getResource("xgboost_java.zip"),
47+
MojoReaderBackendFactory.CachingStrategy.MEMORY);
48+
XGBoostJavaMojoModel mojo = (XGBoostJavaMojoModel) MojoModel.load(readerBackend);
49+
double[] doubles = new double[]{1, 2, 3, 4, 5, 6, 7};
50+
SharedTreeMojoModel.LeafNodeAssignments res = mojo.getLeafNodeAssignments(doubles);
51+
assertNotNull(res._nodeIds);
52+
assertNotNull(res._paths);
53+
String[] paths = mojo.getDecisionPath(doubles);
54+
assertArrayEquals(paths, res._paths);
55+
RowData data = new RowData();
56+
for (int i = 0; i < doubles.length; i++) data.put(mojo._names[i], doubles[i]);
57+
EasyPredictModelWrapper wrapper = new EasyPredictModelWrapper(
58+
new EasyPredictModelWrapper.Config().setModel(mojo).setEnableLeafAssignment(true)
59+
);
60+
RegressionModelPrediction res2 = (RegressionModelPrediction) wrapper.predict(data);
61+
assertNotNull(res2.leafNodeAssignmentIds);
62+
assertNotNull(res2.leafNodeAssignments);
63+
assertArrayEquals(res._nodeIds, res2.leafNodeAssignmentIds);
64+
assertArrayEquals(res._paths, res2.leafNodeAssignments);
65+
}
66+
67+
@Test
68+
public void testConvertWithWeights() throws Exception {
69+
MojoReaderBackend readerBackend = MojoReaderBackendFactory.createReaderBackend(
70+
XGBoostJavaMojoModelTest.class.getResource("xgboost_java.zip"),
71+
MojoReaderBackendFactory.CachingStrategy.MEMORY);
72+
XGBoostJavaMojoModel mojo = (XGBoostJavaMojoModel) MojoModel.load(readerBackend);
73+
SharedTreeGraph graph = mojo.convert(0, null);
74+
int expectedWeight = 380; // prostate dataset, 380 rows
75+
assertEquals(graph.subgraphArray.get(0).rootNode.getWeight(), expectedWeight, 0);
76+
double actualWeight = 0;
77+
for (SharedTreeNode node : graph.subgraphArray.get(0).nodesArray) {
78+
actualWeight += node.getWeight();
79+
}
80+
assertEquals(expectedWeight, actualWeight, 0);
7981
}
80-
assertEquals(expectedWeight, actualWeight, 0);
81-
}
82+
83+
/*
84+
Test for XGBoost model trained with offset and predicting with none zero offset.
85+
*/
86+
@Test
87+
public void testXGBWithOffset_NoneZeroOffset() throws Exception {
88+
String mojofile = String.valueOf(
89+
Paths.get(
90+
XGBoostJavaMojoModelTest.class.getClassLoader().getResource("hex/genmodel/algos/xgboost/XGBoostWithOffset.zip").toURI()
91+
).toFile()
92+
);
8293

83-
private static byte[] serialize(Object o) throws Exception {
84-
ByteArrayOutputStream bos = new ByteArrayOutputStream();
85-
try (ObjectOutput out = new ObjectOutputStream(bos)) {
86-
out.writeObject(o);
94+
EasyPredictModelWrapper.Config config = new EasyPredictModelWrapper.Config()
95+
.setModel(MojoModel.load(mojofile));
96+
EasyPredictModelWrapper model = new EasyPredictModelWrapper(config);
97+
98+
String[][] inputs = new String[][]{
99+
{"1.0", "5.0", "A"},
100+
{"2.0", "4.0", "B"},
101+
{"3.0", "3.0", "A"},
102+
};
103+
104+
double [] offsets = {0.1, 0.2, 0.3};
105+
double[] expected = {9.81646, 19.7366, 29.6913};
106+
double[] preds = new double[inputs.length];
107+
for (int i = 0; i < inputs.length; i++) {
108+
RowData row = new RowData();
109+
row.put("numeric1", inputs[i][0]);
110+
row.put("numeric2", inputs[i][1]);
111+
row.put("categorical", inputs[i][2]);
112+
preds[i] = model.predictRegression(row, offsets[i]).value;
113+
}
114+
115+
assertArrayEquals(expected, preds, 0.0001);
87116
}
88-
return bos.toByteArray();
89-
}
90117

91-
private static Object deserialize(byte[] bs) throws Exception {
92-
try (ByteArrayInputStream bis = new ByteArrayInputStream(bs)) {
93-
ObjectInput in = new ObjectInputStream(bis);
94-
return in.readObject();
118+
/*
119+
Test for XGBoost model trained with offset and predicting with zero offset.
120+
*/
121+
@Test
122+
public void testXGBWithOffset_ZeroOffset() throws Exception {
123+
String mojofile = String.valueOf(
124+
Paths.get(
125+
XGBoostJavaMojoModelTest.class.getClassLoader().getResource("hex/genmodel/algos/xgboost/XGBoostWithOffset.zip").toURI()
126+
).toFile()
127+
);
128+
129+
EasyPredictModelWrapper.Config config = new EasyPredictModelWrapper.Config()
130+
.setModel(MojoModel.load(mojofile));
131+
EasyPredictModelWrapper model = new EasyPredictModelWrapper(config);
132+
133+
String[][] inputs = new String[][]{
134+
{"1.0", "5.0", "A"},
135+
{"2.0", "4.0", "B"},
136+
{"3.0", "3.0", "A"},
137+
};
138+
139+
double[] offsets = {0.0, 0.0, 0.0};
140+
double[] expected = {9.7164, 19.5366, 29.3913};
141+
double[] preds = new double[inputs.length];
142+
for (int i = 0; i < inputs.length; i++) {
143+
RowData row = new RowData();
144+
row.put("numeric1", inputs[i][0]);
145+
row.put("numeric2", inputs[i][1]);
146+
row.put("categorical", inputs[i][2]);
147+
preds[i] = model.predictRegression(row, offsets[i]).value;
148+
}
149+
150+
assertArrayEquals(expected, preds, 0.0001);
95151
}
96-
}
97152

153+
/*
154+
Test for XGBoost model trained without offset and predicting with zero offset.
155+
*/
156+
@Test
157+
public void testXGBWithoutOffset_ZeroOffset() throws Exception {
158+
String mojofile = String.valueOf(
159+
Paths.get(
160+
XGBoostJavaMojoModelTest.class.getClassLoader().getResource("hex/genmodel/algos/xgboost/XGBoostWithoutOffset.zip").toURI()
161+
).toFile()
162+
);
163+
164+
EasyPredictModelWrapper.Config config = new EasyPredictModelWrapper.Config()
165+
.setModel(MojoModel.load(mojofile));
166+
EasyPredictModelWrapper model = new EasyPredictModelWrapper(config);
167+
168+
String[][] inputs = new String[][]{
169+
{"1.0", "5.0", "A"},
170+
{"2.0", "4.0", "B"},
171+
{"3.0", "3.0", "A"},
172+
};
173+
174+
double[] offsets = {0.0, 0.0, 0.0};
175+
double[] expected = {9.8089, 19.7527, 29.6871};
176+
double[] preds = new double[inputs.length];
177+
for (int i = 0; i < inputs.length; i++) {
178+
RowData row = new RowData();
179+
row.put("numeric1", inputs[i][0]);
180+
row.put("numeric2", inputs[i][1]);
181+
row.put("categorical", inputs[i][2]);
182+
preds[i] = model.predictRegression(row, offsets[i]).value;
183+
}
184+
185+
assertArrayEquals(expected, preds, 0.0001);
186+
}
187+
188+
private static byte[] serialize(Object o) throws Exception {
189+
ByteArrayOutputStream bos = new ByteArrayOutputStream();
190+
try (ObjectOutput out = new ObjectOutputStream(bos)) {
191+
out.writeObject(o);
192+
}
193+
return bos.toByteArray();
194+
}
195+
196+
private static Object deserialize(byte[] bs) throws Exception {
197+
try (ByteArrayInputStream bis = new ByteArrayInputStream(bs)) {
198+
ObjectInput in = new ObjectInputStream(bis);
199+
return in.readObject();
200+
}
201+
}
98202

99203
}

0 commit comments

Comments
 (0)