@@ -117,12 +117,8 @@ static ErrorEbm Flatten(BoosterShell* const pBoosterShell,
117
117
const size_t iDimension,
118
118
const Bin<FloatMain, UIntMain, true , true , bHessian>* const * const apBins,
119
119
const TreeNode<bHessian>* pMissingValueTreeNode,
120
- const size_t cSlices
121
- #ifndef NDEBUG
122
- ,
123
- const size_t cBins
124
- #endif // NDEBUG
125
- ) {
120
+ const size_t cSlices,
121
+ const size_t cBins) {
126
122
LOG_0 (Trace_Verbose, " Entered Flatten" );
127
123
128
124
EBM_ASSERT (nullptr != pBoosterShell);
@@ -178,6 +174,8 @@ static ErrorEbm Flatten(BoosterShell* const pBoosterShell,
178
174
pUpdateScore = aUpdateScore;
179
175
180
176
if (bMissing) {
177
+ EBM_ASSERT (2 <= cSlices); // no cuts if there was only missing bin
178
+
181
179
// always put a split on the missing bin
182
180
*pSplit = 1 ;
183
181
++pSplit;
@@ -199,6 +197,7 @@ static ErrorEbm Flatten(BoosterShell* const pBoosterShell,
199
197
const bool bUpdateWithHessian = bHessian && !(TermBoostFlags_DisableNewtonUpdate & flags);
200
198
201
199
TreeNode<bHessian>* pParent = nullptr ;
200
+ bool bDone = false ;
202
201
203
202
while (true ) {
204
203
if (UNPREDICTABLE (pTreeNode->AFTER_IsSplit ())) {
@@ -253,11 +252,6 @@ static ErrorEbm Flatten(BoosterShell* const pBoosterShell,
253
252
}
254
253
EBM_ASSERT (!bNominal);
255
254
256
- // if !bNominal, check the bin above and below for order
257
- EBM_ASSERT (apBins == ppBinLast || *(ppBinLast - 1 ) < *ppBinLast);
258
- EBM_ASSERT (ppBinLast == apBins + (cBins - (nullptr != pMissingValueTreeNode ? size_t {2 } : size_t {1 })) ||
259
- *ppBinLast < *(ppBinLast + 1 ));
260
-
261
255
iEdge = ppBinLast - apBins + 1 + (nullptr != pMissingValueTreeNode ? 1 : 0 );
262
256
263
257
while (true ) { // not a real loop
@@ -267,8 +261,17 @@ static ErrorEbm Flatten(BoosterShell* const pBoosterShell,
267
261
pMissingBin = pTreeNode->GetBin ();
268
262
}
269
263
if (1 == iEdge) {
264
+ // this cut would isolate the missing bin, but we handle those scores separately
270
265
break ;
271
266
}
267
+ } else if (TermBoostFlags_MissingHigh & flags) {
268
+ ++iEdge; // missing is at index 0 in the model, so we are offset by one
269
+ pMissingBin = pTreeNode->GetBin ();
270
+ EBM_ASSERT (iEdge <= cBins + 1 );
271
+ if (bDone) {
272
+ // this cut would isolate the missing bin, but we handle those scores separately
273
+ goto done;
274
+ }
272
275
}
273
276
}
274
277
@@ -290,6 +293,7 @@ static ErrorEbm Flatten(BoosterShell* const pBoosterShell,
290
293
deltaStepMax);
291
294
}
292
295
296
+ EBM_ASSERT (pUpdateScore < aUpdateScore + cScores * cSlices);
293
297
*pUpdateScore = static_cast <FloatScore>(updateScore);
294
298
++pUpdateScore;
295
299
@@ -316,10 +320,27 @@ static ErrorEbm Flatten(BoosterShell* const pBoosterShell,
316
320
317
321
while (true ) {
318
322
if (nullptr == pTreeNode) {
323
+ done:;
319
324
EBM_ASSERT (cSamplesTotalDebug == cSamplesExpectedDebug);
320
325
326
+ EBM_ASSERT (bNominal || pUpdateScore == aUpdateScore + cScores * cSlices);
327
+
328
+ EBM_ASSERT (bNominal || pSplit == cSlices - 1 + pInnerTermUpdate->GetSplitPointer (iDimension));
329
+
330
+ #ifndef NDEBUG
331
+ UIntSplit prevDebug = 0 ;
332
+ for (size_t iDebug = 0 ; iDebug < cSlices - 1 ; ++iDebug) {
333
+ UIntSplit curDebug = pInnerTermUpdate->GetSplitPointer (iDimension)[iDebug];
334
+ EBM_ASSERT (prevDebug < curDebug);
335
+ prevDebug = curDebug;
336
+ }
337
+ EBM_ASSERT (prevDebug < cBins);
338
+ #endif
339
+
321
340
EBM_ASSERT (nullptr == pMissingValueTreeNode || nullptr != pMissingBin);
322
341
if (nullptr != pMissingBin) {
342
+ EBM_ASSERT (bMissing);
343
+
323
344
FloatScore hess = static_cast <FloatCalc>(pMissingBin->GetWeight ());
324
345
const auto * pGradientPair = pMissingBin->GetGradientPairs ();
325
346
const auto * const pGradientPairEnd = pGradientPair + cScores;
@@ -353,12 +374,22 @@ static ErrorEbm Flatten(BoosterShell* const pBoosterShell,
353
374
if (bMissing) {
354
375
if (TermBoostFlags_MissingLow & flags) {
355
376
if (1 == iEdge) {
377
+ // this cut would isolate the missing bin, but missing already has a cut
378
+ break ;
379
+ }
380
+ } else if (TermBoostFlags_MissingHigh & flags) {
381
+ EBM_ASSERT (iEdge <= cBins);
382
+ if (cBins == iEdge) {
383
+ // This cut would isolate the missing bin, but missing already has a cut.
384
+ // We still need to find the missing bin though in the tree, so continue.
385
+ bDone = true ;
356
386
break ;
357
387
}
358
388
}
359
389
}
360
390
361
391
EBM_ASSERT (!IsConvertError<UIntSplit>(iEdge));
392
+ EBM_ASSERT (pSplit < cSlices - 1 + pInnerTermUpdate->GetSplitPointer (iDimension));
362
393
*pSplit = static_cast <UIntSplit>(iEdge);
363
394
++pSplit;
364
395
@@ -865,13 +896,14 @@ template<bool bHessian, size_t cCompilerScores> class PartitionOneDimensionalBoo
865
896
866
897
const TreeNode<bHessian, GetArrayScores (cCompilerScores)>* pMissingValueTreeNode = nullptr ;
867
898
if (TermBoostFlags_MissingLow & flags) {
868
- if (bMissing) {
869
- if (!bNominal) {
870
- pMissingBin = pBin;
871
- }
872
- *ppBin = pBin;
899
+ if (bMissing && !bNominal) {
900
+ pMissingBin = pBin;
901
+ }
902
+ } else if (TermBoostFlags_MissingHigh & flags) {
903
+ if (bMissing && !bNominal) {
904
+ pMissingBin = pBin;
905
+ // the concept of TermBoostFlags_MissingHigh does not exist for nominals
873
906
pBin = IndexBin (pBin, cBytesPerBin);
874
- ++ppBin;
875
907
}
876
908
} else {
877
909
if (bMissing) {
@@ -888,6 +920,13 @@ template<bool bHessian, size_t cCompilerScores> class PartitionOneDimensionalBoo
888
920
++ppBin;
889
921
} while (pBinsEnd != pBin);
890
922
923
+ if (TermBoostFlags_MissingHigh & flags) {
924
+ if (bMissing && !bNominal) {
925
+ *ppBin = aBins;
926
+ ++ppBin;
927
+ }
928
+ }
929
+
891
930
if (bNominal) {
892
931
std::sort (apBins,
893
932
ppBin,
@@ -1072,15 +1111,13 @@ template<bool bHessian, size_t cCompilerScores> class PartitionOneDimensionalBoo
1072
1111
iDimension,
1073
1112
reinterpret_cast <const Bin<FloatMain, UIntMain, true , true , bHessian>* const *>(apBins),
1074
1113
nullptr != pMissingValueTreeNode ? pMissingValueTreeNode->Downgrade () : nullptr ,
1075
- cSlices
1076
- #ifndef NDEBUG
1077
- ,
1078
- cBins
1079
- #endif // NDEBUG
1080
- );
1114
+ cSlices,
1115
+ cBins);
1081
1116
1082
- EBM_ASSERT (!bMissing || 2 <= pBoosterShell->GetInnerTermUpdate ()->GetCountSlices (iDimension));
1083
- EBM_ASSERT (!bMissing || *pBoosterShell->GetInnerTermUpdate ()->GetSplitPointer (iDimension) == 1 );
1117
+ EBM_ASSERT (
1118
+ error != Error_None || !bMissing || 2 <= pBoosterShell->GetInnerTermUpdate ()->GetCountSlices (iDimension));
1119
+ EBM_ASSERT (
1120
+ error != Error_None || !bMissing || *pBoosterShell->GetInnerTermUpdate ()->GetSplitPointer (iDimension) == 1 );
1084
1121
1085
1122
return error;
1086
1123
}
0 commit comments