Skip to content

Commit 2a3376a

Browse files
committed
allow caller to place the missing bin at the high end of the feature values
1 parent 42875bc commit 2a3376a

File tree

2 files changed

+54
-20
lines changed

2 files changed

+54
-20
lines changed

shared/libebm/PartitionOneDimensionalBoosting.cpp

Lines changed: 53 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -117,12 +117,8 @@ static ErrorEbm Flatten(BoosterShell* const pBoosterShell,
117117
const size_t iDimension,
118118
const Bin<FloatMain, UIntMain, true, true, bHessian>* const* const apBins,
119119
const TreeNode<bHessian>* pMissingValueTreeNode,
120-
const size_t cSlices
121-
#ifndef NDEBUG
122-
,
123-
const size_t cBins
124-
#endif // NDEBUG
125-
) {
120+
const size_t cSlices,
121+
const size_t cBins) {
126122
LOG_0(Trace_Verbose, "Entered Flatten");
127123

128124
EBM_ASSERT(nullptr != pBoosterShell);
@@ -178,6 +174,8 @@ static ErrorEbm Flatten(BoosterShell* const pBoosterShell,
178174
pUpdateScore = aUpdateScore;
179175

180176
if(bMissing) {
177+
EBM_ASSERT(2 <= cSlices); // no cuts if there was only missing bin
178+
181179
// always put a split on the missing bin
182180
*pSplit = 1;
183181
++pSplit;
@@ -199,6 +197,7 @@ static ErrorEbm Flatten(BoosterShell* const pBoosterShell,
199197
const bool bUpdateWithHessian = bHessian && !(TermBoostFlags_DisableNewtonUpdate & flags);
200198

201199
TreeNode<bHessian>* pParent = nullptr;
200+
bool bDone = false;
202201

203202
while(true) {
204203
if(UNPREDICTABLE(pTreeNode->AFTER_IsSplit())) {
@@ -253,11 +252,6 @@ static ErrorEbm Flatten(BoosterShell* const pBoosterShell,
253252
}
254253
EBM_ASSERT(!bNominal);
255254

256-
// if !bNominal, check the bin above and below for order
257-
EBM_ASSERT(apBins == ppBinLast || *(ppBinLast - 1) < *ppBinLast);
258-
EBM_ASSERT(ppBinLast == apBins + (cBins - (nullptr != pMissingValueTreeNode ? size_t{2} : size_t{1})) ||
259-
*ppBinLast < *(ppBinLast + 1));
260-
261255
iEdge = ppBinLast - apBins + 1 + (nullptr != pMissingValueTreeNode ? 1 : 0);
262256

263257
while(true) { // not a real loop
@@ -267,8 +261,17 @@ static ErrorEbm Flatten(BoosterShell* const pBoosterShell,
267261
pMissingBin = pTreeNode->GetBin();
268262
}
269263
if(1 == iEdge) {
264+
// this cut would isolate the missing bin, but we handle those scores separately
270265
break;
271266
}
267+
} else if(TermBoostFlags_MissingHigh & flags) {
268+
++iEdge; // missing is at index 0 in the model, so we are offset by one
269+
pMissingBin = pTreeNode->GetBin();
270+
EBM_ASSERT(iEdge <= cBins + 1);
271+
if(bDone) {
272+
// this cut would isolate the missing bin, but we handle those scores separately
273+
goto done;
274+
}
272275
}
273276
}
274277

@@ -316,10 +319,13 @@ static ErrorEbm Flatten(BoosterShell* const pBoosterShell,
316319

317320
while(true) {
318321
if(nullptr == pTreeNode) {
322+
done:;
319323
EBM_ASSERT(cSamplesTotalDebug == cSamplesExpectedDebug);
320324

321325
EBM_ASSERT(nullptr == pMissingValueTreeNode || nullptr != pMissingBin);
322326
if(nullptr != pMissingBin) {
327+
EBM_ASSERT(bMissing);
328+
323329
FloatScore hess = static_cast<FloatCalc>(pMissingBin->GetWeight());
324330
const auto* pGradientPair = pMissingBin->GetGradientPairs();
325331
const auto* const pGradientPairEnd = pGradientPair + cScores;
@@ -341,6 +347,18 @@ static ErrorEbm Flatten(BoosterShell* const pBoosterShell,
341347
} while(pGradientPairEnd != pGradientPair);
342348
}
343349

350+
EBM_ASSERT(bNominal || pSplit == cSlices - 1 + pInnerTermUpdate->GetSplitPointer(iDimension));
351+
352+
#ifndef NDEBUG
353+
UIntSplit prevDebug = 0;
354+
for(size_t iDebug = 0; iDebug < cSlices - 1; ++iDebug) {
355+
UIntSplit curDebug = pInnerTermUpdate->GetSplitPointer(iDimension)[iDebug];
356+
EBM_ASSERT(prevDebug < curDebug);
357+
prevDebug = curDebug;
358+
}
359+
EBM_ASSERT(prevDebug < cBins);
360+
#endif
361+
344362
LOG_0(Trace_Verbose, "Exited Flatten");
345363
return Error_None;
346364
}
@@ -353,12 +371,21 @@ static ErrorEbm Flatten(BoosterShell* const pBoosterShell,
353371
if(bMissing) {
354372
if(TermBoostFlags_MissingLow & flags) {
355373
if(1 == iEdge) {
374+
// this cut would isolate the missing bin, but missing already has a cut
375+
break;
376+
}
377+
} else if(TermBoostFlags_MissingHigh & flags) {
378+
EBM_ASSERT(iEdge <= cBins);
379+
if(cBins == iEdge) {
380+
// This cut would isolate the missing bin, but missing already has a cut.
381+
// We still need to find the missing bin though in the tree, so continue.
356382
break;
357383
}
358384
}
359385
}
360386

361387
EBM_ASSERT(!IsConvertError<UIntSplit>(iEdge));
388+
EBM_ASSERT(pSplit < cSlices - 1 + pInnerTermUpdate->GetSplitPointer(iDimension));
362389
*pSplit = static_cast<UIntSplit>(iEdge);
363390
++pSplit;
364391

@@ -869,9 +896,13 @@ template<bool bHessian, size_t cCompilerScores> class PartitionOneDimensionalBoo
869896
if(!bNominal) {
870897
pMissingBin = pBin;
871898
}
872-
*ppBin = pBin;
899+
}
900+
} else if(TermBoostFlags_MissingHigh & flags) {
901+
if(bMissing) {
902+
if(!bNominal) {
903+
pMissingBin = pBin;
904+
}
873905
pBin = IndexBin(pBin, cBytesPerBin);
874-
++ppBin;
875906
}
876907
} else {
877908
if(bMissing) {
@@ -888,6 +919,13 @@ template<bool bHessian, size_t cCompilerScores> class PartitionOneDimensionalBoo
888919
++ppBin;
889920
} while(pBinsEnd != pBin);
890921

922+
if(TermBoostFlags_MissingHigh & flags) {
923+
if(bMissing) {
924+
*ppBin = aBins;
925+
++ppBin;
926+
}
927+
}
928+
891929
if(bNominal) {
892930
std::sort(apBins,
893931
ppBin,
@@ -1072,12 +1110,8 @@ template<bool bHessian, size_t cCompilerScores> class PartitionOneDimensionalBoo
10721110
iDimension,
10731111
reinterpret_cast<const Bin<FloatMain, UIntMain, true, true, bHessian>* const*>(apBins),
10741112
nullptr != pMissingValueTreeNode ? pMissingValueTreeNode->Downgrade() : nullptr,
1075-
cSlices
1076-
#ifndef NDEBUG
1077-
,
1078-
cBins
1079-
#endif // NDEBUG
1080-
);
1113+
cSlices,
1114+
cBins);
10811115

10821116
EBM_ASSERT(!bMissing || 2 <= pBoosterShell->GetInnerTermUpdate()->GetCountSlices(iDimension));
10831117
EBM_ASSERT(!bMissing || *pBoosterShell->GetInnerTermUpdate()->GetSplitPointer(iDimension) == 1);

shared/libebm/tests/boosting_unusual_inputs.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2175,7 +2175,7 @@ static double RandomizedTesting(const AccelerationFlags acceleration) {
21752175
}
21762176

21772177
TEST_CASE("stress test, boosting") {
2178-
const double expected = 26758407585917.129;
2178+
const double expected = 26746562197367.172;
21792179

21802180
double validationMetricExact = RandomizedTesting(AccelerationFlags_NONE);
21812181
CHECK(validationMetricExact == expected);

0 commit comments

Comments
 (0)