@@ -108,7 +108,7 @@ WARNING_DISABLE_UNINITIALIZED_LOCAL_POINTER
108
108
// do not inline this. Not inlining it makes fewer versions that can be called from the more templated functions
109
109
template <bool bHessian>
110
110
static ErrorEbm Flatten (BoosterShell* const pBoosterShell,
111
- bool bExtraMissingCut ,
111
+ const bool bMissing ,
112
112
const bool bNominal,
113
113
const TermBoostFlags flags,
114
114
const FloatCalc regAlpha,
@@ -132,7 +132,6 @@ static ErrorEbm Flatten(BoosterShell* const pBoosterShell,
132
132
EBM_ASSERT (2 <= cBins);
133
133
EBM_ASSERT (cSlices <= cBins);
134
134
EBM_ASSERT (!bNominal || cSlices == cBins);
135
- EBM_ASSERT (!bExtraMissingCut || !bNominal); // for Nominal we cut everywhere
136
135
137
136
ErrorEbm error;
138
137
@@ -178,7 +177,7 @@ static ErrorEbm Flatten(BoosterShell* const pBoosterShell,
178
177
} else {
179
178
pUpdateScore = aUpdateScore;
180
179
181
- if (nullptr != pMissingValueTreeNode || bExtraMissingCut ) {
180
+ if (bMissing ) {
182
181
// always put a split on the missing bin
183
182
*pSplit = 1 ;
184
183
++pSplit;
@@ -239,18 +238,6 @@ static ErrorEbm Flatten(BoosterShell* const pBoosterShell,
239
238
}
240
239
}
241
240
242
- if (bExtraMissingCut) {
243
- EBM_ASSERT (!bNominal); // for Nominal we cut everywhere
244
- if (TermBoostFlags_MissingLow & flags) {
245
- if (nullptr == pMissingBin) {
246
- pMissingBin = pTreeNode->GetBin ();
247
- }
248
- } else {
249
- EBM_ASSERT (TermBoostFlags_MissingHigh & flags);
250
- pMissingBin = pTreeNode->GetBin ();
251
- }
252
- }
253
-
254
241
EBM_ASSERT (apBins <= ppBinLast);
255
242
EBM_ASSERT (ppBinLast < apBins + (cBins - (nullptr != pMissingValueTreeNode ? size_t {1 } : size_t {0 })));
256
243
@@ -273,41 +260,56 @@ static ErrorEbm Flatten(BoosterShell* const pBoosterShell,
273
260
274
261
iEdge = ppBinLast - apBins + 1 + (nullptr != pMissingValueTreeNode ? 1 : 0 );
275
262
276
- while (true ) {
277
- iScore = 0 ;
278
- do {
279
- FloatCalc updateScore;
280
- if (bUpdateWithHessian) {
281
- updateScore = -CalcNegUpdate<true >(static_cast <FloatCalc>(aGradientPair[iScore].m_sumGradients ),
282
- static_cast <FloatCalc>(aGradientPair[iScore].GetHess ()),
283
- regAlpha,
284
- regLambda,
285
- deltaStepMax);
286
- } else {
287
- updateScore = -CalcNegUpdate<true >(static_cast <FloatCalc>(aGradientPair[iScore].m_sumGradients ),
288
- static_cast <FloatCalc>(pTreeNode->GetBin ()->GetWeight ()),
289
- regAlpha,
290
- regLambda,
291
- deltaStepMax);
263
+ while (true ) { // not a real loop
264
+ if (bMissing) {
265
+ if (TermBoostFlags_MissingLow & flags) {
266
+ if (nullptr == pMissingBin) {
267
+ pMissingBin = pTreeNode->GetBin ();
268
+ }
269
+ if (1 == iEdge) {
270
+ break ;
271
+ }
292
272
}
273
+ }
293
274
294
- *pUpdateScore = static_cast <FloatScore>(updateScore);
295
- ++pUpdateScore;
275
+ while (true ) {
276
+ iScore = 0 ;
277
+ do {
278
+ FloatCalc updateScore;
279
+ if (bUpdateWithHessian) {
280
+ updateScore = -CalcNegUpdate<true >(static_cast <FloatCalc>(aGradientPair[iScore].m_sumGradients ),
281
+ static_cast <FloatCalc>(aGradientPair[iScore].GetHess ()),
282
+ regAlpha,
283
+ regLambda,
284
+ deltaStepMax);
285
+ } else {
286
+ updateScore = -CalcNegUpdate<true >(static_cast <FloatCalc>(aGradientPair[iScore].m_sumGradients ),
287
+ static_cast <FloatCalc>(pTreeNode->GetBin ()->GetWeight ()),
288
+ regAlpha,
289
+ regLambda,
290
+ deltaStepMax);
291
+ }
296
292
297
- ++iScore;
298
- } while (cScores != iScore);
299
- if (nullptr == ppBinCur) {
300
- break ;
301
- }
302
- EBM_ASSERT (bNominal);
303
- ++ppBinCur;
304
- if (ppBinLast < ppBinCur) {
305
- break ;
293
+ *pUpdateScore = static_cast <FloatScore>(updateScore);
294
+ ++pUpdateScore;
295
+
296
+ ++iScore;
297
+ } while (cScores != iScore);
298
+ if (nullptr == ppBinCur) {
299
+ break ;
300
+ }
301
+ EBM_ASSERT (bNominal);
302
+ ++ppBinCur;
303
+ if (ppBinLast < ppBinCur) {
304
+ break ;
305
+ }
306
+ determine_bin:;
307
+ const auto * const pBinCur = *ppBinCur;
308
+ const size_t iBin = CountBins (pBinCur, aBins, cBytesPerBin);
309
+ pUpdateScore = aUpdateScore + iBin * cScores;
306
310
}
307
- determine_bin:;
308
- const auto * const pBinCur = *ppBinCur;
309
- const size_t iBin = CountBins (pBinCur, aBins, cBytesPerBin);
310
- pUpdateScore = aUpdateScore + iBin * cScores;
311
+
312
+ break ;
311
313
}
312
314
313
315
pTreeNode = pParent;
@@ -345,9 +347,23 @@ static ErrorEbm Flatten(BoosterShell* const pBoosterShell,
345
347
if (!pTreeNode->DECONSTRUCT_IsRightChildTraversal ()) {
346
348
// we checked earlier that countBins could be converted to a UIntSplit
347
349
if (nullptr == ppBinCur) {
348
- EBM_ASSERT (!IsConvertError<UIntSplit>(iEdge));
349
- *pSplit = static_cast <UIntSplit>(iEdge);
350
- ++pSplit;
350
+ EBM_ASSERT (!bNominal);
351
+
352
+ while (true ) { // not a real loop
353
+ if (bMissing) {
354
+ if (TermBoostFlags_MissingLow & flags) {
355
+ if (1 == iEdge) {
356
+ break ;
357
+ }
358
+ }
359
+ }
360
+
361
+ EBM_ASSERT (!IsConvertError<UIntSplit>(iEdge));
362
+ *pSplit = static_cast <UIntSplit>(iEdge);
363
+ ++pSplit;
364
+
365
+ break ;
366
+ }
351
367
}
352
368
pParent = pTreeNode;
353
369
pTreeNode = pTreeNode->DECONSTRUCT_TraverseRightAndMark (cBytesPerTreeNode);
@@ -832,7 +848,10 @@ template<bool bHessian, size_t cCompilerScores> class PartitionOneDimensionalBoo
832
848
auto * const aBins =
833
849
pBoosterShell->GetBoostingMainBins ()
834
850
->Specialize <FloatMain, UIntMain, true , true , bHessian, GetArrayScores (cCompilerScores)>();
835
- auto * const pBinsEnd = IndexBin (aBins, cBytesPerBin * cBins);
851
+ auto * pBinsEnd = IndexBin (aBins, cBytesPerBin * cBins);
852
+
853
+ SumAllBins<bHessian, cCompilerScores>(
854
+ pBoosterShell, pBinsEnd, cSamplesTotal, weightTotal, pRootTreeNode->GetBin ());
836
855
837
856
const Bin<FloatMain, UIntMain, true , true , bHessian, GetArrayScores (cCompilerScores)>** const apBins =
838
857
reinterpret_cast <const Bin<FloatMain, UIntMain, true , true , bHessian, GetArrayScores (cCompilerScores)>**>(
@@ -844,7 +863,6 @@ template<bool bHessian, size_t cCompilerScores> class PartitionOneDimensionalBoo
844
863
const Bin<FloatMain, UIntMain, true , true , bHessian, GetArrayScores (cCompilerScores)>* pMissingBin = nullptr ;
845
864
bool bMissingIsolated = false ;
846
865
847
- size_t cBinsAdjusted = cBins;
848
866
const TreeNode<bHessian, GetArrayScores (cCompilerScores)>* pMissingValueTreeNode = nullptr ;
849
867
if (TermBoostFlags_MissingLow & flags) {
850
868
if (bMissing) {
@@ -861,32 +879,25 @@ template<bool bHessian, size_t cCompilerScores> class PartitionOneDimensionalBoo
861
879
// Skip the missing bin in the pointer to pointer mapping since it will not be part of the continuous
862
880
// region.
863
881
pBin = IndexBin (pBin, cBytesPerBin);
864
- --cBinsAdjusted;
865
882
}
866
883
}
867
884
868
- const Bin<FloatMain, UIntMain, true , true , bHessian, GetArrayScores (cCompilerScores)>** ppBinsEnd =
869
- apBins + cBinsAdjusted;
870
-
871
885
do {
872
886
*ppBin = pBin;
873
887
pBin = IndexBin (pBin, cBytesPerBin);
874
888
++ppBin;
875
- } while (ppBinsEnd != ppBin );
889
+ } while (pBinsEnd != pBin );
876
890
877
891
if (bNominal) {
878
892
std::sort (apBins,
879
- ppBinsEnd ,
893
+ ppBin ,
880
894
CompareBin<bHessian, cCompilerScores>(
881
895
!(TermBoostFlags_DisableNewtonUpdate & flags), categoricalSmoothing));
882
896
}
883
897
884
898
pRootTreeNode->BEFORE_SetBinFirst (apBins);
885
- pRootTreeNode->BEFORE_SetBinLast (ppBinsEnd - 1 );
886
- ASSERT_BIN_OK (cBytesPerBin, *(ppBinsEnd - 1 ), pBoosterShell->GetDebugMainBinsEnd ());
887
-
888
- SumAllBins<bHessian, cCompilerScores>(
889
- pBoosterShell, pBinsEnd, cSamplesTotal, weightTotal, pRootTreeNode->GetBin ());
899
+ pRootTreeNode->BEFORE_SetBinLast (ppBin - 1 );
900
+ ASSERT_BIN_OK (cBytesPerBin, *(ppBin - 1 ), pBoosterShell->GetDebugMainBinsEnd ());
890
901
891
902
EBM_ASSERT (!IsOverflowTreeNodeSize (bHessian, cScores));
892
903
const size_t cBytesPerTreeNode = GetTreeNodeSize (bHessian, cScores);
@@ -1040,21 +1051,19 @@ template<bool bHessian, size_t cCompilerScores> class PartitionOneDimensionalBoo
1040
1051
*pTotalGain = static_cast <double >(totalGain);
1041
1052
1042
1053
size_t cSlices = cSplitsMax - cSplitsRemaining + 1 ;
1043
- bool bExtraMissingCut = false ;
1044
1054
if (nullptr != pMissingValueTreeNode) {
1045
1055
EBM_ASSERT (nullptr == pMissingBin);
1046
1056
++cSlices;
1047
1057
} else {
1048
1058
if (nullptr != pMissingBin && !bMissingIsolated) {
1049
- bExtraMissingCut = true ;
1050
1059
++cSlices;
1051
1060
}
1052
1061
}
1053
1062
if (bNominal) {
1054
1063
cSlices = cBins;
1055
1064
}
1056
1065
const ErrorEbm error = Flatten<bHessian>(pBoosterShell,
1057
- bExtraMissingCut ,
1066
+ bMissing ,
1058
1067
bNominal,
1059
1068
flags,
1060
1069
regAlpha,
0 commit comments