@@ -117,12 +117,8 @@ static ErrorEbm Flatten(BoosterShell* const pBoosterShell,
117117 const size_t iDimension,
118118 const Bin<FloatMain, UIntMain, true , true , bHessian>* const * const apBins,
119119 const TreeNode<bHessian>* pMissingValueTreeNode,
120- const size_t cSlices
121- #ifndef NDEBUG
122- ,
123- const size_t cBins
124- #endif // NDEBUG
125- ) {
120+ const size_t cSlices,
121+ const size_t cBins) {
126122 LOG_0 (Trace_Verbose, " Entered Flatten" );
127123
128124 EBM_ASSERT (nullptr != pBoosterShell);
@@ -178,6 +174,8 @@ static ErrorEbm Flatten(BoosterShell* const pBoosterShell,
178174 pUpdateScore = aUpdateScore;
179175
180176 if (bMissing) {
177+ EBM_ASSERT (2 <= cSlices); // no cuts if there was only missing bin
178+
181179 // always put a split on the missing bin
182180 *pSplit = 1 ;
183181 ++pSplit;
@@ -199,6 +197,7 @@ static ErrorEbm Flatten(BoosterShell* const pBoosterShell,
199197 const bool bUpdateWithHessian = bHessian && !(TermBoostFlags_DisableNewtonUpdate & flags);
200198
201199 TreeNode<bHessian>* pParent = nullptr ;
200+ bool bDone = false ;
202201
203202 while (true ) {
204203 if (UNPREDICTABLE (pTreeNode->AFTER_IsSplit ())) {
@@ -253,11 +252,6 @@ static ErrorEbm Flatten(BoosterShell* const pBoosterShell,
253252 }
254253 EBM_ASSERT (!bNominal);
255254
256- // if !bNominal, check the bin above and below for order
257- EBM_ASSERT (apBins == ppBinLast || *(ppBinLast - 1 ) < *ppBinLast);
258- EBM_ASSERT (ppBinLast == apBins + (cBins - (nullptr != pMissingValueTreeNode ? size_t {2 } : size_t {1 })) ||
259- *ppBinLast < *(ppBinLast + 1 ));
260-
261255 iEdge = ppBinLast - apBins + 1 + (nullptr != pMissingValueTreeNode ? 1 : 0 );
262256
263257 while (true ) { // not a real loop
@@ -267,8 +261,17 @@ static ErrorEbm Flatten(BoosterShell* const pBoosterShell,
267261 pMissingBin = pTreeNode->GetBin ();
268262 }
269263 if (1 == iEdge) {
264+ // this cut would isolate the missing bin, but we handle those scores separately
270265 break ;
271266 }
267+ } else if (TermBoostFlags_MissingHigh & flags) {
268+ ++iEdge; // missing is at index 0 in the model, so we are offset by one
269+ pMissingBin = pTreeNode->GetBin ();
270+ EBM_ASSERT (iEdge <= cBins + 1 );
271+ if (bDone) {
272+ // this cut would isolate the missing bin, but we handle those scores separately
273+ goto done;
274+ }
272275 }
273276 }
274277
@@ -290,6 +293,7 @@ static ErrorEbm Flatten(BoosterShell* const pBoosterShell,
290293 deltaStepMax);
291294 }
292295
296+ EBM_ASSERT (pUpdateScore < aUpdateScore + cScores * cSlices);
293297 *pUpdateScore = static_cast <FloatScore>(updateScore);
294298 ++pUpdateScore;
295299
@@ -316,10 +320,27 @@ static ErrorEbm Flatten(BoosterShell* const pBoosterShell,
316320
317321 while (true ) {
318322 if (nullptr == pTreeNode) {
323+ done:;
319324 EBM_ASSERT (cSamplesTotalDebug == cSamplesExpectedDebug);
320325
326+ EBM_ASSERT (bNominal || pUpdateScore == aUpdateScore + cScores * cSlices);
327+
328+ EBM_ASSERT (bNominal || pSplit == cSlices - 1 + pInnerTermUpdate->GetSplitPointer (iDimension));
329+
330+ #ifndef NDEBUG
331+ UIntSplit prevDebug = 0 ;
332+ for (size_t iDebug = 0 ; iDebug < cSlices - 1 ; ++iDebug) {
333+ UIntSplit curDebug = pInnerTermUpdate->GetSplitPointer (iDimension)[iDebug];
334+ EBM_ASSERT (prevDebug < curDebug);
335+ prevDebug = curDebug;
336+ }
337+ EBM_ASSERT (prevDebug < cBins);
338+ #endif
339+
321340 EBM_ASSERT (nullptr == pMissingValueTreeNode || nullptr != pMissingBin);
322341 if (nullptr != pMissingBin) {
342+ EBM_ASSERT (bMissing);
343+
323344 FloatScore hess = static_cast <FloatCalc>(pMissingBin->GetWeight ());
324345 const auto * pGradientPair = pMissingBin->GetGradientPairs ();
325346 const auto * const pGradientPairEnd = pGradientPair + cScores;
@@ -353,12 +374,22 @@ static ErrorEbm Flatten(BoosterShell* const pBoosterShell,
353374 if (bMissing) {
354375 if (TermBoostFlags_MissingLow & flags) {
355376 if (1 == iEdge) {
377+ // this cut would isolate the missing bin, but missing already has a cut
378+ break ;
379+ }
380+ } else if (TermBoostFlags_MissingHigh & flags) {
381+ EBM_ASSERT (iEdge <= cBins);
382+ if (cBins == iEdge) {
383+ // This cut would isolate the missing bin, but missing already has a cut.
384+ // We still need to find the missing bin though in the tree, so continue.
385+ bDone = true ;
356386 break ;
357387 }
358388 }
359389 }
360390
361391 EBM_ASSERT (!IsConvertError<UIntSplit>(iEdge));
392+ EBM_ASSERT (pSplit < cSlices - 1 + pInnerTermUpdate->GetSplitPointer (iDimension));
362393 *pSplit = static_cast <UIntSplit>(iEdge);
363394 ++pSplit;
364395
@@ -865,13 +896,14 @@ template<bool bHessian, size_t cCompilerScores> class PartitionOneDimensionalBoo
865896
866897 const TreeNode<bHessian, GetArrayScores (cCompilerScores)>* pMissingValueTreeNode = nullptr ;
867898 if (TermBoostFlags_MissingLow & flags) {
868- if (bMissing) {
869- if (!bNominal) {
870- pMissingBin = pBin;
871- }
872- *ppBin = pBin;
899+ if (bMissing && !bNominal) {
900+ pMissingBin = pBin;
901+ }
902+ } else if (TermBoostFlags_MissingHigh & flags) {
903+ if (bMissing && !bNominal) {
904+ pMissingBin = pBin;
905+ // the concept of TermBoostFlags_MissingHigh does not exist for nominals
873906 pBin = IndexBin (pBin, cBytesPerBin);
874- ++ppBin;
875907 }
876908 } else {
877909 if (bMissing) {
@@ -888,6 +920,13 @@ template<bool bHessian, size_t cCompilerScores> class PartitionOneDimensionalBoo
888920 ++ppBin;
889921 } while (pBinsEnd != pBin);
890922
923+ if (TermBoostFlags_MissingHigh & flags) {
924+ if (bMissing && !bNominal) {
925+ *ppBin = aBins;
926+ ++ppBin;
927+ }
928+ }
929+
891930 if (bNominal) {
892931 std::sort (apBins,
893932 ppBin,
@@ -1072,15 +1111,13 @@ template<bool bHessian, size_t cCompilerScores> class PartitionOneDimensionalBoo
10721111 iDimension,
10731112 reinterpret_cast <const Bin<FloatMain, UIntMain, true , true , bHessian>* const *>(apBins),
10741113 nullptr != pMissingValueTreeNode ? pMissingValueTreeNode->Downgrade () : nullptr ,
1075- cSlices
1076- #ifndef NDEBUG
1077- ,
1078- cBins
1079- #endif // NDEBUG
1080- );
1114+ cSlices,
1115+ cBins);
10811116
1082- EBM_ASSERT (!bMissing || 2 <= pBoosterShell->GetInnerTermUpdate ()->GetCountSlices (iDimension));
1083- EBM_ASSERT (!bMissing || *pBoosterShell->GetInnerTermUpdate ()->GetSplitPointer (iDimension) == 1 );
1117+ EBM_ASSERT (
1118+ error != Error_None || !bMissing || 2 <= pBoosterShell->GetInnerTermUpdate ()->GetCountSlices (iDimension));
1119+ EBM_ASSERT (
1120+ error != Error_None || !bMissing || *pBoosterShell->GetInnerTermUpdate ()->GetSplitPointer (iDimension) == 1 );
10841121
10851122 return error;
10861123 }
0 commit comments