Skip to content

Commit 5e999b6

Browse files
committed
allow caller to place the missing bin at the high end of the feature values
1 parent 42875bc commit 5e999b6

File tree

2 files changed

+63
-26
lines changed

2 files changed

+63
-26
lines changed

shared/libebm/PartitionOneDimensionalBoosting.cpp

+62-25
Original file line numberDiff line numberDiff line change
@@ -117,12 +117,8 @@ static ErrorEbm Flatten(BoosterShell* const pBoosterShell,
117117
const size_t iDimension,
118118
const Bin<FloatMain, UIntMain, true, true, bHessian>* const* const apBins,
119119
const TreeNode<bHessian>* pMissingValueTreeNode,
120-
const size_t cSlices
121-
#ifndef NDEBUG
122-
,
123-
const size_t cBins
124-
#endif // NDEBUG
125-
) {
120+
const size_t cSlices,
121+
const size_t cBins) {
126122
LOG_0(Trace_Verbose, "Entered Flatten");
127123

128124
EBM_ASSERT(nullptr != pBoosterShell);
@@ -178,6 +174,8 @@ static ErrorEbm Flatten(BoosterShell* const pBoosterShell,
178174
pUpdateScore = aUpdateScore;
179175

180176
if(bMissing) {
177+
EBM_ASSERT(2 <= cSlices); // no cuts if there was only missing bin
178+
181179
// always put a split on the missing bin
182180
*pSplit = 1;
183181
++pSplit;
@@ -199,6 +197,7 @@ static ErrorEbm Flatten(BoosterShell* const pBoosterShell,
199197
const bool bUpdateWithHessian = bHessian && !(TermBoostFlags_DisableNewtonUpdate & flags);
200198

201199
TreeNode<bHessian>* pParent = nullptr;
200+
bool bDone = false;
202201

203202
while(true) {
204203
if(UNPREDICTABLE(pTreeNode->AFTER_IsSplit())) {
@@ -253,11 +252,6 @@ static ErrorEbm Flatten(BoosterShell* const pBoosterShell,
253252
}
254253
EBM_ASSERT(!bNominal);
255254

256-
// if !bNominal, check the bin above and below for order
257-
EBM_ASSERT(apBins == ppBinLast || *(ppBinLast - 1) < *ppBinLast);
258-
EBM_ASSERT(ppBinLast == apBins + (cBins - (nullptr != pMissingValueTreeNode ? size_t{2} : size_t{1})) ||
259-
*ppBinLast < *(ppBinLast + 1));
260-
261255
iEdge = ppBinLast - apBins + 1 + (nullptr != pMissingValueTreeNode ? 1 : 0);
262256

263257
while(true) { // not a real loop
@@ -267,8 +261,17 @@ static ErrorEbm Flatten(BoosterShell* const pBoosterShell,
267261
pMissingBin = pTreeNode->GetBin();
268262
}
269263
if(1 == iEdge) {
264+
// this cut would isolate the missing bin, but we handle those scores separately
270265
break;
271266
}
267+
} else if(TermBoostFlags_MissingHigh & flags) {
268+
++iEdge; // missing is at index 0 in the model, so we are offset by one
269+
pMissingBin = pTreeNode->GetBin();
270+
EBM_ASSERT(iEdge <= cBins + 1);
271+
if(bDone) {
272+
// this cut would isolate the missing bin, but we handle those scores separately
273+
goto done;
274+
}
272275
}
273276
}
274277

@@ -290,6 +293,7 @@ static ErrorEbm Flatten(BoosterShell* const pBoosterShell,
290293
deltaStepMax);
291294
}
292295

296+
EBM_ASSERT(pUpdateScore < aUpdateScore + cScores * cSlices);
293297
*pUpdateScore = static_cast<FloatScore>(updateScore);
294298
++pUpdateScore;
295299

@@ -316,10 +320,27 @@ static ErrorEbm Flatten(BoosterShell* const pBoosterShell,
316320

317321
while(true) {
318322
if(nullptr == pTreeNode) {
323+
done:;
319324
EBM_ASSERT(cSamplesTotalDebug == cSamplesExpectedDebug);
320325

326+
EBM_ASSERT(bNominal || pUpdateScore == aUpdateScore + cScores * cSlices);
327+
328+
EBM_ASSERT(bNominal || pSplit == cSlices - 1 + pInnerTermUpdate->GetSplitPointer(iDimension));
329+
330+
#ifndef NDEBUG
331+
UIntSplit prevDebug = 0;
332+
for(size_t iDebug = 0; iDebug < cSlices - 1; ++iDebug) {
333+
UIntSplit curDebug = pInnerTermUpdate->GetSplitPointer(iDimension)[iDebug];
334+
EBM_ASSERT(prevDebug < curDebug);
335+
prevDebug = curDebug;
336+
}
337+
EBM_ASSERT(prevDebug < cBins);
338+
#endif
339+
321340
EBM_ASSERT(nullptr == pMissingValueTreeNode || nullptr != pMissingBin);
322341
if(nullptr != pMissingBin) {
342+
EBM_ASSERT(bMissing);
343+
323344
FloatScore hess = static_cast<FloatCalc>(pMissingBin->GetWeight());
324345
const auto* pGradientPair = pMissingBin->GetGradientPairs();
325346
const auto* const pGradientPairEnd = pGradientPair + cScores;
@@ -353,12 +374,22 @@ static ErrorEbm Flatten(BoosterShell* const pBoosterShell,
353374
if(bMissing) {
354375
if(TermBoostFlags_MissingLow & flags) {
355376
if(1 == iEdge) {
377+
// this cut would isolate the missing bin, but missing already has a cut
378+
break;
379+
}
380+
} else if(TermBoostFlags_MissingHigh & flags) {
381+
EBM_ASSERT(iEdge <= cBins);
382+
if(cBins == iEdge) {
383+
// This cut would isolate the missing bin, but missing already has a cut.
384+
// We still need to locate the missing bin in the tree, so keep traversing.
385+
bDone = true;
356386
break;
357387
}
358388
}
359389
}
360390

361391
EBM_ASSERT(!IsConvertError<UIntSplit>(iEdge));
392+
EBM_ASSERT(pSplit < cSlices - 1 + pInnerTermUpdate->GetSplitPointer(iDimension));
362393
*pSplit = static_cast<UIntSplit>(iEdge);
363394
++pSplit;
364395

@@ -865,13 +896,14 @@ template<bool bHessian, size_t cCompilerScores> class PartitionOneDimensionalBoo
865896

866897
const TreeNode<bHessian, GetArrayScores(cCompilerScores)>* pMissingValueTreeNode = nullptr;
867898
if(TermBoostFlags_MissingLow & flags) {
868-
if(bMissing) {
869-
if(!bNominal) {
870-
pMissingBin = pBin;
871-
}
872-
*ppBin = pBin;
899+
if(bMissing && !bNominal) {
900+
pMissingBin = pBin;
901+
}
902+
} else if(TermBoostFlags_MissingHigh & flags) {
903+
if(bMissing && !bNominal) {
904+
pMissingBin = pBin;
905+
// the concept of TermBoostFlags_MissingHigh does not exist for nominals
873906
pBin = IndexBin(pBin, cBytesPerBin);
874-
++ppBin;
875907
}
876908
} else {
877909
if(bMissing) {
@@ -888,6 +920,13 @@ template<bool bHessian, size_t cCompilerScores> class PartitionOneDimensionalBoo
888920
++ppBin;
889921
} while(pBinsEnd != pBin);
890922

923+
if(TermBoostFlags_MissingHigh & flags) {
924+
if(bMissing && !bNominal) {
925+
*ppBin = aBins;
926+
++ppBin;
927+
}
928+
}
929+
891930
if(bNominal) {
892931
std::sort(apBins,
893932
ppBin,
@@ -1072,15 +1111,13 @@ template<bool bHessian, size_t cCompilerScores> class PartitionOneDimensionalBoo
10721111
iDimension,
10731112
reinterpret_cast<const Bin<FloatMain, UIntMain, true, true, bHessian>* const*>(apBins),
10741113
nullptr != pMissingValueTreeNode ? pMissingValueTreeNode->Downgrade() : nullptr,
1075-
cSlices
1076-
#ifndef NDEBUG
1077-
,
1078-
cBins
1079-
#endif // NDEBUG
1080-
);
1114+
cSlices,
1115+
cBins);
10811116

1082-
EBM_ASSERT(!bMissing || 2 <= pBoosterShell->GetInnerTermUpdate()->GetCountSlices(iDimension));
1083-
EBM_ASSERT(!bMissing || *pBoosterShell->GetInnerTermUpdate()->GetSplitPointer(iDimension) == 1);
1117+
EBM_ASSERT(
1118+
error != Error_None || !bMissing || 2 <= pBoosterShell->GetInnerTermUpdate()->GetCountSlices(iDimension));
1119+
EBM_ASSERT(
1120+
error != Error_None || !bMissing || *pBoosterShell->GetInnerTermUpdate()->GetSplitPointer(iDimension) == 1);
10841121

10851122
return error;
10861123
}

shared/libebm/tests/boosting_unusual_inputs.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -2175,7 +2175,7 @@ static double RandomizedTesting(const AccelerationFlags acceleration) {
21752175
}
21762176

21772177
TEST_CASE("stress test, boosting") {
2178-
const double expected = 26758407585917.129;
2178+
const double expected = 26746562197367.172;
21792179

21802180
double validationMetricExact = RandomizedTesting(AccelerationFlags_NONE);
21812181
CHECK(validationMetricExact == expected);

0 commit comments

Comments
 (0)