@@ -2088,9 +2088,11 @@ TEST_CASE("stress test, boosting") {
2088
2088
// terms.push_back({0, 1, 2, 3}); // TODO: enable when fast enough
2089
2089
}
2090
2090
const size_t cRounds = 200 ;
2091
- std::vector<IntEbm> boostFlagsAny{TermBoostFlags_PurifyGain,
2091
+ std::vector<IntEbm> boostFlagsAny{// TermBoostFlags_PurifyGain,
2092
2092
TermBoostFlags_DisableNewtonGain,
2093
2093
TermBoostFlags_DisableCategorical,
2094
+ // TermBoostFlags_PurifyUpdate,
2095
+ // TermBoostFlags_GradientSums, // does not return a metric
2094
2096
TermBoostFlags_DisableNewtonUpdate,
2095
2097
TermBoostFlags_RandomSplits};
2096
2098
std::vector<IntEbm> boostFlagsChoose{TermBoostFlags_Default,
@@ -2099,10 +2101,10 @@ TEST_CASE("stress test, boosting") {
2099
2101
TermBoostFlags_MissingSeparate,
2100
2102
TermBoostFlags_MissingDrop};
2101
2103
2102
- double validationMetric = 0 .0 ;
2104
+ double validationMetric = 1 .0 ;
2103
2105
2104
2106
for (IntEbm classesCount = Task_Regression; classesCount < 5 ; ++classesCount) {
2105
- if (classesCount != Task_Regression && classesCount < 2 ) {
2107
+ if (classesCount != Task_Regression && classesCount < 1 ) {
2106
2108
continue ;
2107
2109
}
2108
2110
const auto train = MakeRandomDataset (rng, classesCount, cTrainSamples, features);
@@ -2159,9 +2161,13 @@ TEST_CASE("stress test, boosting") {
2159
2161
.validationMetric ;
2160
2162
}
2161
2163
}
2162
- validationMetric += validationMetricIteration;
2164
+ if (classesCount == 1 ) {
2165
+ CHECK (std::numeric_limits<double >::infinity () == validationMetricIteration);
2166
+ } else {
2167
+ validationMetric *= validationMetricIteration;
2168
+ }
2163
2169
}
2164
2170
}
2165
2171
2166
- CHECK (validationMetric == 42031.143270308334 );
2172
+ CHECK (validationMetric == 62013566170252.117 );
2167
2173
}
0 commit comments