Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -78,13 +78,27 @@ public double[] score0(double[] row, double[] predictions) {
@Override
public double[] score0(double[] row, double offset, double[] predictions) {
int[] enumOffset = null;
if (_interaction_targets != null) {
enumOffset = evaluateInteractions(row);

if (_nums == -1) {
predictions[0] = forCategories(row) + forOtherColumns(row) - forStrata(row) + offset;
} else {
if (_interaction_targets != null) {
enumOffset = evaluateInteractions(row);
}
predictions[0] = forCategories(row) + forOtherColumns(row, enumOffset) - forStrata(row) + offset;
}
predictions[0] = forCategories(row) + forOtherColumns(row, enumOffset) - forStrata(row) + offset;
return predictions;
}

private double forOtherColumns(double[] row) {
double result = 0.0;
int catOffsetDiff = _cat_offsets[_cats] - _cats;
for(int i = _cats ; i + catOffsetDiff < _coef.length; i++) {
result += _coef[catOffsetDiff + i] * featureValue(row, i);
}
return result;
}

private double forOtherColumns(double[] row, int[] enumOffset) {
double result = 0.0;
int coefLen = _coef.length;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ protected void readModelData() throws IOException {
_model._interaction_targets = readkv("interaction_targets");
_model._interaction_column_index = new HashSet<>();
_model._interaction_column_domains = new HashMap<>();
_model._nums = readkv("num_numerical_columns");
_model._nums = readkv("num_numerical_columns", -1);
_model._num_offsets = readkv("num_offsets");

if (_model._interaction_targets != null) {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
package hex.genmodel.algos.coxph;

import org.junit.Test;
import java.nio.file.Paths;

import hex.genmodel.MojoModel;
import hex.genmodel.easy.RowData;
import hex.genmodel.easy.EasyPredictModelWrapper;

import static org.junit.Assert.*;

public class CoxPHMojoModelTest {
Expand Down Expand Up @@ -33,4 +39,163 @@ public void testForOneCategory() {
assertEquals(0.6, mojo.forOneCategory(row, 1, 0), 0);
}

/*
Test backward compatibility of CoxPH mojo model trained in 3.32.1.3.
Only using all features (both categorical and numeric).
*/
@Test
public void testCoxPHBackwardCompatibilityAll332() throws Exception {
String mojofile = String.valueOf(
Paths.get(
CoxPHMojoModelTest.class.getClassLoader().getResource("hex/genmodel/algos/coxph/CoxPH_bc_all_3_32.zip").toURI()
).toFile()
);

EasyPredictModelWrapper.Config config = new EasyPredictModelWrapper.Config()
.setModel(MojoModel.load(mojofile));
EasyPredictModelWrapper model = new EasyPredictModelWrapper(config);

String [][] inputs = new String[][] {
{"0", "50", "1", "-17.1553730321697", "0.123203285420945","0",
"0", "1", "-2.3004572363122033", "-23.840464872142775", "c0.l2", "c1.l2"},
{"0", "6", "1", "3.83572895277207", "0.254620123203285", "0",
"0", "2", "33.01605679101838", "-12.944002705270874" ,"c0.l1","c1.l2"},
{"0", "1", "0", "6.29705681040383", "0.265571526351814", "0",
"0", "3", "-22.54601829241052", "77.61631885563669", "c0.l2", "c1.l0"}
};
double [] expected = {0.0902431, 0.422815, 0.663896};
double[] preds = new double[inputs.length];
for (int i = 0; i < inputs.length; i++) {
RowData row = new RowData();
row.put("start", inputs[i][0]);
row.put("stop", inputs[i][1]);
row.put("event", inputs[i][2]);
row.put("age", inputs[i][3]);
row.put("year", inputs[i][4]);
row.put("surgery", inputs[i][5]);
row.put("transplant", inputs[i][6]);
row.put("id", inputs[i][7]);
row.put("C1", inputs[i][8]);
row.put("C2", inputs[i][9]);
row.put("C3", inputs[i][10]);
row.put("C4", inputs[i][11]);
preds[i] = model.predictCoxPH(row).value;
}

assertArrayEquals(expected, preds, 0.000001);
}

/*
Test backward compatibility of CoxPH mojo model trained in 3.42.0.4.
Only using all features (both categorical and numeric).
*/
@Test
public void testCoxPHBackwardCompatibilityAll342() throws Exception {
String mojofile = String.valueOf(
Paths.get(
CoxPHMojoModelTest.class.getClassLoader().getResource("hex/genmodel/algos/coxph/CoxPH_bc_all_3_42.zip").toURI()
).toFile()
);

EasyPredictModelWrapper.Config config = new EasyPredictModelWrapper.Config()
.setModel(MojoModel.load(mojofile));
EasyPredictModelWrapper model = new EasyPredictModelWrapper(config);

String [][] inputs = new String[][] {
{"0", "50", "1", "-17.1553730321697", "0.123203285420945","0",
"0", "1", "-2.3004572363122033", "-23.840464872142775", "c0.l2", "c1.l2"},
{"0", "6", "1", "3.83572895277207", "0.254620123203285", "0",
"0", "2", "33.01605679101838", "-12.944002705270874" ,"c0.l1","c1.l2"},
{"0", "1", "0", "6.29705681040383", "0.265571526351814", "0",
"0", "3", "-22.54601829241052", "77.61631885563669", "c0.l2", "c1.l0"}
};
double [] expected = {0.0902431, 0.422815, 0.663896};
double[] preds = new double[inputs.length];
for (int i = 0; i < inputs.length; i++) {
RowData row = new RowData();
row.put("start", inputs[i][0]);
row.put("stop", inputs[i][1]);
row.put("event", inputs[i][2]);
row.put("age", inputs[i][3]);
row.put("year", inputs[i][4]);
row.put("surgery", inputs[i][5]);
row.put("transplant", inputs[i][6]);
row.put("id", inputs[i][7]);
row.put("C1", inputs[i][8]);
row.put("C2", inputs[i][9]);
row.put("C3", inputs[i][10]);
row.put("C4", inputs[i][11]);
preds[i] = model.predictCoxPH(row).value;
}

assertArrayEquals(expected, preds, 0.000001);
}

/*
Test backward compatibility of CoxPH mojo model trained in 3.32.1.3.
Only using categorical features.
*/
@Test
public void testCoxPHBackwardCompatibilityCatOnly332() throws Exception {
String mojofile = String.valueOf(
Paths.get(
CoxPHMojoModelTest.class.getClassLoader().getResource("hex/genmodel/algos/coxph/CoxPH_bc_catOnly_3_32.zip").toURI()
).toFile()
);

EasyPredictModelWrapper.Config config = new EasyPredictModelWrapper.Config()
.setModel(MojoModel.load(mojofile));
EasyPredictModelWrapper model = new EasyPredictModelWrapper(config);

String [][] inputs = new String[][] {
{"0", "0"},
{"0", "1"},
{"1", "0"},
{"1", "1"},
};
double [] expected = {0.0628001, 0.221134, -0.686394, -0.528061};
double[] preds = new double[inputs.length];
for (int i = 0; i < inputs.length; i++) {
RowData row = new RowData();
row.put("surgery", inputs[i][0]);
row.put("transplant", inputs[i][1]);
preds[i] = model.predictCoxPH(row).value;
}

assertArrayEquals(expected, preds, 0.000001);
}

/*
Test backward compatibility of CoxPH mojo model trained in 3.42.0.4.
Only using categorical features.
*/
@Test
public void testCoxPHBackwardCompatibilityCatOnly342() throws Exception {
String mojofile = String.valueOf(
Paths.get(
CoxPHMojoModelTest.class.getClassLoader().getResource("hex/genmodel/algos/coxph/CoxPH_bc_catOnly_3_42.zip").toURI()
).toFile()
);

EasyPredictModelWrapper.Config config = new EasyPredictModelWrapper.Config()
.setModel(MojoModel.load(mojofile));
EasyPredictModelWrapper model = new EasyPredictModelWrapper(config);

String [][] inputs = new String[][] {
{"0", "0"},
{"0", "1"},
{"1", "0"},
{"1", "1"},
};
double [] expected = {0.0628001, 0.221134, -0.686394, -0.528061};
double[] preds = new double[inputs.length];
for (int i = 0; i < inputs.length; i++) {
RowData row = new RowData();
row.put("surgery", inputs[i][0]);
row.put("transplant", inputs[i][1]);
preds[i] = model.predictCoxPH(row).value;
}

assertArrayEquals(expected, preds, 0.000001);
}
}
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Loading