@@ -117,12 +117,8 @@ static ErrorEbm Flatten(BoosterShell* const pBoosterShell,
117117 const size_t iDimension,
118118 const Bin<FloatMain, UIntMain, true , true , bHessian>* const * const apBins,
119119 const TreeNode<bHessian>* pMissingValueTreeNode,
120- const size_t cSlices
121- #ifndef NDEBUG
122- ,
123- const size_t cBins
124- #endif // NDEBUG
125- ) {
120+ const size_t cSlices,
121+ const size_t cBins) {
126122 LOG_0 (Trace_Verbose, " Entered Flatten" );
127123
128124 EBM_ASSERT (nullptr != pBoosterShell);
@@ -178,6 +174,8 @@ static ErrorEbm Flatten(BoosterShell* const pBoosterShell,
178174 pUpdateScore = aUpdateScore;
179175
180176 if (bMissing) {
177+ EBM_ASSERT (2 <= cSlices); // no cuts if there was only missing bin
178+
181179 // always put a split on the missing bin
182180 *pSplit = 1 ;
183181 ++pSplit;
@@ -199,6 +197,7 @@ static ErrorEbm Flatten(BoosterShell* const pBoosterShell,
199197 const bool bUpdateWithHessian = bHessian && !(TermBoostFlags_DisableNewtonUpdate & flags);
200198
201199 TreeNode<bHessian>* pParent = nullptr ;
200+ bool bDone = false ;
202201
203202 while (true ) {
204203 if (UNPREDICTABLE (pTreeNode->AFTER_IsSplit ())) {
@@ -253,11 +252,6 @@ static ErrorEbm Flatten(BoosterShell* const pBoosterShell,
253252 }
254253 EBM_ASSERT (!bNominal);
255254
256- // if !bNominal, check the bin above and below for order
257- EBM_ASSERT (apBins == ppBinLast || *(ppBinLast - 1 ) < *ppBinLast);
258- EBM_ASSERT (ppBinLast == apBins + (cBins - (nullptr != pMissingValueTreeNode ? size_t {2 } : size_t {1 })) ||
259- *ppBinLast < *(ppBinLast + 1 ));
260-
261255 iEdge = ppBinLast - apBins + 1 + (nullptr != pMissingValueTreeNode ? 1 : 0 );
262256
263257 while (true ) { // not a real loop
@@ -267,8 +261,17 @@ static ErrorEbm Flatten(BoosterShell* const pBoosterShell,
267261 pMissingBin = pTreeNode->GetBin ();
268262 }
269263 if (1 == iEdge) {
264+ // this cut would isolate the missing bin, but we handle those scores separately
270265 break ;
271266 }
267+ } else if (TermBoostFlags_MissingHigh & flags) {
268+ ++iEdge; // missing is at index 0 in the model, so we are offset by one
269+ pMissingBin = pTreeNode->GetBin ();
270+ EBM_ASSERT (iEdge <= cBins + 1 );
271+ if (bDone) {
272+ // this cut would isolate the missing bin, but we handle those scores separately
273+ goto done;
274+ }
272275 }
273276 }
274277
@@ -316,10 +319,13 @@ static ErrorEbm Flatten(BoosterShell* const pBoosterShell,
316319
317320 while (true ) {
318321 if (nullptr == pTreeNode) {
322+ done:;
319323 EBM_ASSERT (cSamplesTotalDebug == cSamplesExpectedDebug);
320324
321325 EBM_ASSERT (nullptr == pMissingValueTreeNode || nullptr != pMissingBin);
322326 if (nullptr != pMissingBin) {
327+ EBM_ASSERT (bMissing);
328+
323329 FloatScore hess = static_cast <FloatCalc>(pMissingBin->GetWeight ());
324330 const auto * pGradientPair = pMissingBin->GetGradientPairs ();
325331 const auto * const pGradientPairEnd = pGradientPair + cScores;
@@ -341,6 +347,18 @@ static ErrorEbm Flatten(BoosterShell* const pBoosterShell,
341347 } while (pGradientPairEnd != pGradientPair);
342348 }
343349
350+ EBM_ASSERT (bNominal || pSplit == cSlices - 1 + pInnerTermUpdate->GetSplitPointer (iDimension));
351+
352+ #ifndef NDEBUG
353+ UIntSplit prevDebug = 0 ;
354+ for (size_t iDebug = 0 ; iDebug < cSlices - 1 ; ++iDebug) {
355+ UIntSplit curDebug = pInnerTermUpdate->GetSplitPointer (iDimension)[iDebug];
356+ EBM_ASSERT (prevDebug < curDebug);
357+ prevDebug = curDebug;
358+ }
359+ EBM_ASSERT (prevDebug < cBins);
360+ #endif
361+
344362 LOG_0 (Trace_Verbose, " Exited Flatten" );
345363 return Error_None;
346364 }
@@ -353,12 +371,21 @@ static ErrorEbm Flatten(BoosterShell* const pBoosterShell,
353371 if (bMissing) {
354372 if (TermBoostFlags_MissingLow & flags) {
355373 if (1 == iEdge) {
374+ // this cut would isolate the missing bin, but missing already has a cut
375+ break ;
376+ }
377+ } else if (TermBoostFlags_MissingHigh & flags) {
378+ EBM_ASSERT (iEdge <= cBins);
379+ if (cBins == iEdge) {
380+ // This cut would isolate the missing bin, but missing already has a cut.
381+ // We still need to find the missing bin though in the tree, so continue.
356382 break ;
357383 }
358384 }
359385 }
360386
361387 EBM_ASSERT (!IsConvertError<UIntSplit>(iEdge));
388+ EBM_ASSERT (pSplit < cSlices - 1 + pInnerTermUpdate->GetSplitPointer (iDimension));
362389 *pSplit = static_cast <UIntSplit>(iEdge);
363390 ++pSplit;
364391
@@ -869,9 +896,13 @@ template<bool bHessian, size_t cCompilerScores> class PartitionOneDimensionalBoo
869896 if (!bNominal) {
870897 pMissingBin = pBin;
871898 }
872- *ppBin = pBin;
899+ }
900+ } else if (TermBoostFlags_MissingHigh & flags) {
901+ if (bMissing) {
902+ if (!bNominal) {
903+ pMissingBin = pBin;
904+ }
873905 pBin = IndexBin (pBin, cBytesPerBin);
874- ++ppBin;
875906 }
876907 } else {
877908 if (bMissing) {
@@ -888,6 +919,13 @@ template<bool bHessian, size_t cCompilerScores> class PartitionOneDimensionalBoo
888919 ++ppBin;
889920 } while (pBinsEnd != pBin);
890921
922+ if (TermBoostFlags_MissingHigh & flags) {
923+ if (bMissing) {
924+ *ppBin = aBins;
925+ ++ppBin;
926+ }
927+ }
928+
891929 if (bNominal) {
892930 std::sort (apBins,
893931 ppBin,
@@ -1072,12 +1110,8 @@ template<bool bHessian, size_t cCompilerScores> class PartitionOneDimensionalBoo
10721110 iDimension,
10731111 reinterpret_cast <const Bin<FloatMain, UIntMain, true , true , bHessian>* const *>(apBins),
10741112 nullptr != pMissingValueTreeNode ? pMissingValueTreeNode->Downgrade () : nullptr ,
1075- cSlices
1076- #ifndef NDEBUG
1077- ,
1078- cBins
1079- #endif // NDEBUG
1080- );
1113+ cSlices,
1114+ cBins);
10811115
10821116 EBM_ASSERT (!bMissing || 2 <= pBoosterShell->GetInnerTermUpdate ()->GetCountSlices (iDimension));
10831117 EBM_ASSERT (!bMissing || *pBoosterShell->GetInnerTermUpdate ()->GetSplitPointer (iDimension) == 1 );
0 commit comments