Skip to content

Commit 1cbaeb8

Browse files
committed
fix inconsistent results in ARM exact tests by removing purification from tests
1 parent ac6452e commit 1cbaeb8

File tree

1 file changed

+11
-5
lines changed

1 file changed

+11
-5
lines changed

shared/libebm/tests/boosting_unusual_inputs.cpp

+11-5
Original file line numberDiff line numberDiff line change
@@ -2088,9 +2088,11 @@ TEST_CASE("stress test, boosting") {
20882088
// terms.push_back({0, 1, 2, 3}); // TODO: enable when fast enough
20892089
}
20902090
const size_t cRounds = 200;
2091-
std::vector<IntEbm> boostFlagsAny{TermBoostFlags_PurifyGain,
2091+
std::vector<IntEbm> boostFlagsAny{// TermBoostFlags_PurifyGain,
20922092
TermBoostFlags_DisableNewtonGain,
20932093
TermBoostFlags_DisableCategorical,
2094+
// TermBoostFlags_PurifyUpdate,
2095+
// TermBoostFlags_GradientSums, // does not return a metric
20942096
TermBoostFlags_DisableNewtonUpdate,
20952097
TermBoostFlags_RandomSplits};
20962098
std::vector<IntEbm> boostFlagsChoose{TermBoostFlags_Default,
@@ -2099,10 +2101,10 @@ TEST_CASE("stress test, boosting") {
20992101
TermBoostFlags_MissingSeparate,
21002102
TermBoostFlags_MissingDrop};
21012103

2102-
double validationMetric = 0.0;
2104+
double validationMetric = 1.0;
21032105

21042106
for(IntEbm classesCount = Task_Regression; classesCount < 5; ++classesCount) {
2105-
if(classesCount != Task_Regression && classesCount < 2) {
2107+
if(classesCount != Task_Regression && classesCount < 1) {
21062108
continue;
21072109
}
21082110
const auto train = MakeRandomDataset(rng, classesCount, cTrainSamples, features);
@@ -2159,9 +2161,13 @@ TEST_CASE("stress test, boosting") {
21592161
.validationMetric;
21602162
}
21612163
}
2162-
validationMetric += validationMetricIteration;
2164+
if(classesCount == 1) {
2165+
CHECK(std::numeric_limits<double>::infinity() == validationMetricIteration);
2166+
} else {
2167+
validationMetric *= validationMetricIteration;
2168+
}
21632169
}
21642170
}
21652171

2166-
CHECK(validationMetric == 42031.143270308334);
2172+
CHECK(validationMetric == 62013566170252.117);
21672173
}

0 commit comments

Comments
 (0)