|
7 | 7 | import hex.genmodel.algos.tree.SharedTreeGraph; |
8 | 8 | import hex.genmodel.algos.tree.SharedTreeMojoModel; |
9 | 9 | import hex.genmodel.algos.tree.SharedTreeNode; |
10 | | -import hex.genmodel.algos.tree.SharedTreeSubgraph; |
11 | 10 | import hex.genmodel.easy.EasyPredictModelWrapper; |
12 | 11 | import hex.genmodel.easy.RowData; |
13 | 12 | import hex.genmodel.easy.prediction.RegressionModelPrediction; |
14 | 13 | import org.junit.Test; |
15 | 14 |
|
16 | 15 | import java.io.*; |
| 16 | +import java.nio.file.Paths; |
17 | 17 |
|
18 | 18 | import static org.junit.Assert.*; |
19 | 19 |
|
20 | 20 | public class XGBoostJavaMojoModelTest { |
21 | 21 |
|
22 | | - @Test |
23 | | - public void testObjFunction() { // make sure we have implementation for all supported obj functions |
24 | | - for (XGBoostMojoModel.ObjectiveType type : XGBoostMojoModel.ObjectiveType.values()) { |
25 | | - assertNotNull(type.getId()); |
26 | | - assertFalse(type.getId().isEmpty()); |
27 | | - // check we have an implementation of ObjFunction |
28 | | - assertNotNull(XGBoostJavaMojoModel.getObjFunction(type.getId())); |
| 22 | + @Test |
| 23 | + public void testObjFunction() { // make sure we have implementation for all supported obj functions |
| 24 | + for (XGBoostMojoModel.ObjectiveType type : XGBoostMojoModel.ObjectiveType.values()) { |
| 25 | + assertNotNull(type.getId()); |
| 26 | + assertFalse(type.getId().isEmpty()); |
| 27 | + // check we have an implementation of ObjFunction |
| 28 | + assertNotNull(XGBoostJavaMojoModel.getObjFunction(type.getId())); |
| 29 | + } |
29 | 30 | } |
30 | | - } |
31 | | - |
32 | | - @Test |
33 | | - public void testPredictContributionsSerialization() throws Exception { |
34 | | - MojoReaderBackend readerBackend = MojoReaderBackendFactory.createReaderBackend( |
35 | | - XGBoostJavaMojoModelTest.class.getResource("xgboost_java.zip"), |
36 | | - MojoReaderBackendFactory.CachingStrategy.MEMORY); |
37 | | - XGBoostJavaMojoModel mojo = (XGBoostJavaMojoModel) MojoModel.load(readerBackend); |
38 | | - PredictContributions pc = mojo.makeContributionsPredictor(); |
39 | | - assertNotNull(pc); |
40 | | - assertTrue(deserialize(serialize(pc)) instanceof PredictContributions); |
41 | | - } |
42 | | - |
43 | | - @Test |
44 | | - public void testLeafNodeAssignments() throws Exception { |
45 | | - MojoReaderBackend readerBackend = MojoReaderBackendFactory.createReaderBackend( |
46 | | - getClass().getResource("xgboost_java.zip"), |
47 | | - MojoReaderBackendFactory.CachingStrategy.MEMORY); |
48 | | - XGBoostJavaMojoModel mojo = (XGBoostJavaMojoModel) MojoModel.load(readerBackend); |
49 | | - double[] doubles = new double[]{1, 2, 3, 4, 5, 6, 7}; |
50 | | - SharedTreeMojoModel.LeafNodeAssignments res = mojo.getLeafNodeAssignments(doubles); |
51 | | - assertNotNull(res._nodeIds); |
52 | | - assertNotNull(res._paths); |
53 | | - String[] paths = mojo.getDecisionPath(doubles); |
54 | | - assertArrayEquals(paths, res._paths); |
55 | | - RowData data = new RowData(); |
56 | | - for (int i = 0; i< doubles.length; i++) data.put(mojo._names[i], doubles[i]); |
57 | | - EasyPredictModelWrapper wrapper = new EasyPredictModelWrapper( |
58 | | - new EasyPredictModelWrapper.Config().setModel(mojo).setEnableLeafAssignment(true) |
59 | | - ); |
60 | | - RegressionModelPrediction res2 = (RegressionModelPrediction) wrapper.predict(data); |
61 | | - assertNotNull(res2.leafNodeAssignmentIds); |
62 | | - assertNotNull(res2.leafNodeAssignments); |
63 | | - assertArrayEquals(res._nodeIds, res2.leafNodeAssignmentIds); |
64 | | - assertArrayEquals(res._paths, res2.leafNodeAssignments); |
65 | | - } |
66 | | - |
67 | | - @Test |
68 | | - public void testConvertWithWeights() throws Exception { |
69 | | - MojoReaderBackend readerBackend = MojoReaderBackendFactory.createReaderBackend( |
70 | | - XGBoostJavaMojoModelTest.class.getResource("xgboost_java.zip"), |
71 | | - MojoReaderBackendFactory.CachingStrategy.MEMORY); |
72 | | - XGBoostJavaMojoModel mojo = (XGBoostJavaMojoModel) MojoModel.load(readerBackend); |
73 | | - SharedTreeGraph graph = mojo.convert(0, null); |
74 | | - int expectedWeight = 380; // prostate dataset, 380 rows |
75 | | - assertEquals(graph.subgraphArray.get(0).rootNode.getWeight(), expectedWeight, 0); |
76 | | - double actualWeight = 0; |
77 | | - for (SharedTreeNode node : graph.subgraphArray.get(0).nodesArray) { |
78 | | - actualWeight += node.getWeight(); |
| 31 | + |
| 32 | + @Test |
| 33 | + public void testPredictContributionsSerialization() throws Exception { |
| 34 | + MojoReaderBackend readerBackend = MojoReaderBackendFactory.createReaderBackend( |
| 35 | + XGBoostJavaMojoModelTest.class.getResource("xgboost_java.zip"), |
| 36 | + MojoReaderBackendFactory.CachingStrategy.MEMORY); |
| 37 | + XGBoostJavaMojoModel mojo = (XGBoostJavaMojoModel) MojoModel.load(readerBackend); |
| 38 | + PredictContributions pc = mojo.makeContributionsPredictor(); |
| 39 | + assertNotNull(pc); |
| 40 | + assertTrue(deserialize(serialize(pc)) instanceof PredictContributions); |
| 41 | + } |
| 42 | + |
| 43 | + @Test |
| 44 | + public void testLeafNodeAssignments() throws Exception { |
| 45 | + MojoReaderBackend readerBackend = MojoReaderBackendFactory.createReaderBackend( |
| 46 | + getClass().getResource("xgboost_java.zip"), |
| 47 | + MojoReaderBackendFactory.CachingStrategy.MEMORY); |
| 48 | + XGBoostJavaMojoModel mojo = (XGBoostJavaMojoModel) MojoModel.load(readerBackend); |
| 49 | + double[] doubles = new double[]{1, 2, 3, 4, 5, 6, 7}; |
| 50 | + SharedTreeMojoModel.LeafNodeAssignments res = mojo.getLeafNodeAssignments(doubles); |
| 51 | + assertNotNull(res._nodeIds); |
| 52 | + assertNotNull(res._paths); |
| 53 | + String[] paths = mojo.getDecisionPath(doubles); |
| 54 | + assertArrayEquals(paths, res._paths); |
| 55 | + RowData data = new RowData(); |
| 56 | + for (int i = 0; i < doubles.length; i++) data.put(mojo._names[i], doubles[i]); |
| 57 | + EasyPredictModelWrapper wrapper = new EasyPredictModelWrapper( |
| 58 | + new EasyPredictModelWrapper.Config().setModel(mojo).setEnableLeafAssignment(true) |
| 59 | + ); |
| 60 | + RegressionModelPrediction res2 = (RegressionModelPrediction) wrapper.predict(data); |
| 61 | + assertNotNull(res2.leafNodeAssignmentIds); |
| 62 | + assertNotNull(res2.leafNodeAssignments); |
| 63 | + assertArrayEquals(res._nodeIds, res2.leafNodeAssignmentIds); |
| 64 | + assertArrayEquals(res._paths, res2.leafNodeAssignments); |
| 65 | + } |
| 66 | + |
| 67 | + @Test |
| 68 | + public void testConvertWithWeights() throws Exception { |
| 69 | + MojoReaderBackend readerBackend = MojoReaderBackendFactory.createReaderBackend( |
| 70 | + XGBoostJavaMojoModelTest.class.getResource("xgboost_java.zip"), |
| 71 | + MojoReaderBackendFactory.CachingStrategy.MEMORY); |
| 72 | + XGBoostJavaMojoModel mojo = (XGBoostJavaMojoModel) MojoModel.load(readerBackend); |
| 73 | + SharedTreeGraph graph = mojo.convert(0, null); |
| 74 | + int expectedWeight = 380; // prostate dataset, 380 rows |
| 75 | + assertEquals(graph.subgraphArray.get(0).rootNode.getWeight(), expectedWeight, 0); |
| 76 | + double actualWeight = 0; |
| 77 | + for (SharedTreeNode node : graph.subgraphArray.get(0).nodesArray) { |
| 78 | + actualWeight += node.getWeight(); |
| 79 | + } |
| 80 | + assertEquals(expectedWeight, actualWeight, 0); |
79 | 81 | } |
80 | | - assertEquals(expectedWeight, actualWeight, 0); |
81 | | - } |
| 82 | + |
| 83 | + /* |
| 84 | + Test for XGBoost model trained with offset and predicting with none zero offset. |
| 85 | + */ |
| 86 | + @Test |
| 87 | + public void testXGBWithOffset_NoneZeroOffset() throws Exception { |
| 88 | + String mojofile = String.valueOf( |
| 89 | + Paths.get( |
| 90 | + XGBoostJavaMojoModelTest.class.getClassLoader().getResource("hex/genmodel/algos/xgboost/XGBoostWithOffset.zip").toURI() |
| 91 | + ).toFile() |
| 92 | + ); |
82 | 93 |
|
83 | | - private static byte[] serialize(Object o) throws Exception { |
84 | | - ByteArrayOutputStream bos = new ByteArrayOutputStream(); |
85 | | - try (ObjectOutput out = new ObjectOutputStream(bos)) { |
86 | | - out.writeObject(o); |
| 94 | + EasyPredictModelWrapper.Config config = new EasyPredictModelWrapper.Config() |
| 95 | + .setModel(MojoModel.load(mojofile)); |
| 96 | + EasyPredictModelWrapper model = new EasyPredictModelWrapper(config); |
| 97 | + |
| 98 | + String[][] inputs = new String[][]{ |
| 99 | + {"1.0", "5.0", "A"}, |
| 100 | + {"2.0", "4.0", "B"}, |
| 101 | + {"3.0", "3.0", "A"}, |
| 102 | + }; |
| 103 | + |
| 104 | + double [] offsets = {0.1, 0.2, 0.3}; |
| 105 | + double[] expected = {9.81646, 19.7366, 29.6913}; |
| 106 | + double[] preds = new double[inputs.length]; |
| 107 | + for (int i = 0; i < inputs.length; i++) { |
| 108 | + RowData row = new RowData(); |
| 109 | + row.put("numeric1", inputs[i][0]); |
| 110 | + row.put("numeric2", inputs[i][1]); |
| 111 | + row.put("categorical", inputs[i][2]); |
| 112 | + preds[i] = model.predictRegression(row, offsets[i]).value; |
| 113 | + } |
| 114 | + |
| 115 | + assertArrayEquals(expected, preds, 0.0001); |
87 | 116 | } |
88 | | - return bos.toByteArray(); |
89 | | - } |
90 | 117 |
|
91 | | - private static Object deserialize(byte[] bs) throws Exception { |
92 | | - try (ByteArrayInputStream bis = new ByteArrayInputStream(bs)) { |
93 | | - ObjectInput in = new ObjectInputStream(bis); |
94 | | - return in.readObject(); |
| 118 | + /* |
| 119 | + Test for XGBoost model trained with offset and predicting with zero offset. |
| 120 | + */ |
| 121 | + @Test |
| 122 | + public void testXGBWithOffset_ZeroOffset() throws Exception { |
| 123 | + String mojofile = String.valueOf( |
| 124 | + Paths.get( |
| 125 | + XGBoostJavaMojoModelTest.class.getClassLoader().getResource("hex/genmodel/algos/xgboost/XGBoostWithOffset.zip").toURI() |
| 126 | + ).toFile() |
| 127 | + ); |
| 128 | + |
| 129 | + EasyPredictModelWrapper.Config config = new EasyPredictModelWrapper.Config() |
| 130 | + .setModel(MojoModel.load(mojofile)); |
| 131 | + EasyPredictModelWrapper model = new EasyPredictModelWrapper(config); |
| 132 | + |
| 133 | + String[][] inputs = new String[][]{ |
| 134 | + {"1.0", "5.0", "A"}, |
| 135 | + {"2.0", "4.0", "B"}, |
| 136 | + {"3.0", "3.0", "A"}, |
| 137 | + }; |
| 138 | + |
| 139 | + double[] offsets = {0.0, 0.0, 0.0}; |
| 140 | + double[] expected = {9.7164, 19.5366, 29.3913}; |
| 141 | + double[] preds = new double[inputs.length]; |
| 142 | + for (int i = 0; i < inputs.length; i++) { |
| 143 | + RowData row = new RowData(); |
| 144 | + row.put("numeric1", inputs[i][0]); |
| 145 | + row.put("numeric2", inputs[i][1]); |
| 146 | + row.put("categorical", inputs[i][2]); |
| 147 | + preds[i] = model.predictRegression(row, offsets[i]).value; |
| 148 | + } |
| 149 | + |
| 150 | + assertArrayEquals(expected, preds, 0.0001); |
95 | 151 | } |
96 | | - } |
97 | 152 |
|
| 153 | + /* |
| 154 | + Test for XGBoost model trained without offset and predicting with zero offset. |
| 155 | + */ |
| 156 | + @Test |
| 157 | + public void testXGBWithoutOffset_ZeroOffset() throws Exception { |
| 158 | + String mojofile = String.valueOf( |
| 159 | + Paths.get( |
| 160 | + XGBoostJavaMojoModelTest.class.getClassLoader().getResource("hex/genmodel/algos/xgboost/XGBoostWithoutOffset.zip").toURI() |
| 161 | + ).toFile() |
| 162 | + ); |
| 163 | + |
| 164 | + EasyPredictModelWrapper.Config config = new EasyPredictModelWrapper.Config() |
| 165 | + .setModel(MojoModel.load(mojofile)); |
| 166 | + EasyPredictModelWrapper model = new EasyPredictModelWrapper(config); |
| 167 | + |
| 168 | + String[][] inputs = new String[][]{ |
| 169 | + {"1.0", "5.0", "A"}, |
| 170 | + {"2.0", "4.0", "B"}, |
| 171 | + {"3.0", "3.0", "A"}, |
| 172 | + }; |
| 173 | + |
| 174 | + double[] offsets = {0.0, 0.0, 0.0}; |
| 175 | + double[] expected = {9.8089, 19.7527, 29.6871}; |
| 176 | + double[] preds = new double[inputs.length]; |
| 177 | + for (int i = 0; i < inputs.length; i++) { |
| 178 | + RowData row = new RowData(); |
| 179 | + row.put("numeric1", inputs[i][0]); |
| 180 | + row.put("numeric2", inputs[i][1]); |
| 181 | + row.put("categorical", inputs[i][2]); |
| 182 | + preds[i] = model.predictRegression(row, offsets[i]).value; |
| 183 | + } |
| 184 | + |
| 185 | + assertArrayEquals(expected, preds, 0.0001); |
| 186 | + } |
| 187 | + |
| 188 | + private static byte[] serialize(Object o) throws Exception { |
| 189 | + ByteArrayOutputStream bos = new ByteArrayOutputStream(); |
| 190 | + try (ObjectOutput out = new ObjectOutputStream(bos)) { |
| 191 | + out.writeObject(o); |
| 192 | + } |
| 193 | + return bos.toByteArray(); |
| 194 | + } |
| 195 | + |
| 196 | + private static Object deserialize(byte[] bs) throws Exception { |
| 197 | + try (ByteArrayInputStream bis = new ByteArrayInputStream(bs)) { |
| 198 | + ObjectInput in = new ObjectInputStream(bis); |
| 199 | + return in.readObject(); |
| 200 | + } |
| 201 | + } |
98 | 202 |
|
99 | 203 | } |
0 commit comments