@@ -108,7 +108,7 @@ WARNING_DISABLE_UNINITIALIZED_LOCAL_POINTER
108108// do not inline this. Not inlining it makes fewer versions that can be called from the more templated functions
109109template <bool bHessian>
110110static ErrorEbm Flatten (BoosterShell* const pBoosterShell,
111- bool bExtraMissingCut ,
111+ const bool bMissing ,
112112 const bool bNominal,
113113 const TermBoostFlags flags,
114114 const FloatCalc regAlpha,
@@ -132,7 +132,6 @@ static ErrorEbm Flatten(BoosterShell* const pBoosterShell,
132132 EBM_ASSERT (2 <= cBins);
133133 EBM_ASSERT (cSlices <= cBins);
134134 EBM_ASSERT (!bNominal || cSlices == cBins);
135- EBM_ASSERT (!bExtraMissingCut || !bNominal); // for Nominal we cut everywhere
136135
137136 ErrorEbm error;
138137
@@ -178,7 +177,7 @@ static ErrorEbm Flatten(BoosterShell* const pBoosterShell,
178177 } else {
179178 pUpdateScore = aUpdateScore;
180179
181- if (nullptr != pMissingValueTreeNode || bExtraMissingCut ) {
180+ if (bMissing ) {
182181 // always put a split on the missing bin
183182 *pSplit = 1 ;
184183 ++pSplit;
@@ -239,18 +238,6 @@ static ErrorEbm Flatten(BoosterShell* const pBoosterShell,
239238 }
240239 }
241240
242- if (bExtraMissingCut) {
243- EBM_ASSERT (!bNominal); // for Nominal we cut everywhere
244- if (TermBoostFlags_MissingLow & flags) {
245- if (nullptr == pMissingBin) {
246- pMissingBin = pTreeNode->GetBin ();
247- }
248- } else {
249- EBM_ASSERT (TermBoostFlags_MissingHigh & flags);
250- pMissingBin = pTreeNode->GetBin ();
251- }
252- }
253-
254241 EBM_ASSERT (apBins <= ppBinLast);
255242 EBM_ASSERT (ppBinLast < apBins + (cBins - (nullptr != pMissingValueTreeNode ? size_t {1 } : size_t {0 })));
256243
@@ -273,41 +260,56 @@ static ErrorEbm Flatten(BoosterShell* const pBoosterShell,
273260
274261 iEdge = ppBinLast - apBins + 1 + (nullptr != pMissingValueTreeNode ? 1 : 0 );
275262
276- while (true ) {
277- iScore = 0 ;
278- do {
279- FloatCalc updateScore;
280- if (bUpdateWithHessian) {
281- updateScore = -CalcNegUpdate<true >(static_cast <FloatCalc>(aGradientPair[iScore].m_sumGradients ),
282- static_cast <FloatCalc>(aGradientPair[iScore].GetHess ()),
283- regAlpha,
284- regLambda,
285- deltaStepMax);
286- } else {
287- updateScore = -CalcNegUpdate<true >(static_cast <FloatCalc>(aGradientPair[iScore].m_sumGradients ),
288- static_cast <FloatCalc>(pTreeNode->GetBin ()->GetWeight ()),
289- regAlpha,
290- regLambda,
291- deltaStepMax);
263+ while (true ) { // not a real loop
264+ if (bMissing) {
265+ if (TermBoostFlags_MissingLow & flags) {
266+ if (nullptr == pMissingBin) {
267+ pMissingBin = pTreeNode->GetBin ();
268+ }
269+ if (1 == iEdge) {
270+ break ;
271+ }
292272 }
273+ }
293274
294- *pUpdateScore = static_cast <FloatScore>(updateScore);
295- ++pUpdateScore;
275+ while (true ) {
276+ iScore = 0 ;
277+ do {
278+ FloatCalc updateScore;
279+ if (bUpdateWithHessian) {
280+ updateScore = -CalcNegUpdate<true >(static_cast <FloatCalc>(aGradientPair[iScore].m_sumGradients ),
281+ static_cast <FloatCalc>(aGradientPair[iScore].GetHess ()),
282+ regAlpha,
283+ regLambda,
284+ deltaStepMax);
285+ } else {
286+ updateScore = -CalcNegUpdate<true >(static_cast <FloatCalc>(aGradientPair[iScore].m_sumGradients ),
287+ static_cast <FloatCalc>(pTreeNode->GetBin ()->GetWeight ()),
288+ regAlpha,
289+ regLambda,
290+ deltaStepMax);
291+ }
296292
297- ++iScore;
298- } while (cScores != iScore);
299- if (nullptr == ppBinCur) {
300- break ;
301- }
302- EBM_ASSERT (bNominal);
303- ++ppBinCur;
304- if (ppBinLast < ppBinCur) {
305- break ;
293+ *pUpdateScore = static_cast <FloatScore>(updateScore);
294+ ++pUpdateScore;
295+
296+ ++iScore;
297+ } while (cScores != iScore);
298+ if (nullptr == ppBinCur) {
299+ break ;
300+ }
301+ EBM_ASSERT (bNominal);
302+ ++ppBinCur;
303+ if (ppBinLast < ppBinCur) {
304+ break ;
305+ }
306+ determine_bin:;
307+ const auto * const pBinCur = *ppBinCur;
308+ const size_t iBin = CountBins (pBinCur, aBins, cBytesPerBin);
309+ pUpdateScore = aUpdateScore + iBin * cScores;
306310 }
307- determine_bin:;
308- const auto * const pBinCur = *ppBinCur;
309- const size_t iBin = CountBins (pBinCur, aBins, cBytesPerBin);
310- pUpdateScore = aUpdateScore + iBin * cScores;
311+
312+ break ;
311313 }
312314
313315 pTreeNode = pParent;
@@ -345,9 +347,23 @@ static ErrorEbm Flatten(BoosterShell* const pBoosterShell,
345347 if (!pTreeNode->DECONSTRUCT_IsRightChildTraversal ()) {
346348 // we checked earlier that countBins could be converted to a UIntSplit
347349 if (nullptr == ppBinCur) {
348- EBM_ASSERT (!IsConvertError<UIntSplit>(iEdge));
349- *pSplit = static_cast <UIntSplit>(iEdge);
350- ++pSplit;
350+ EBM_ASSERT (!bNominal);
351+
352+ while (true ) { // not a real loop
353+ if (bMissing) {
354+ if (TermBoostFlags_MissingLow & flags) {
355+ if (1 == iEdge) {
356+ break ;
357+ }
358+ }
359+ }
360+
361+ EBM_ASSERT (!IsConvertError<UIntSplit>(iEdge));
362+ *pSplit = static_cast <UIntSplit>(iEdge);
363+ ++pSplit;
364+
365+ break ;
366+ }
351367 }
352368 pParent = pTreeNode;
353369 pTreeNode = pTreeNode->DECONSTRUCT_TraverseRightAndMark (cBytesPerTreeNode);
@@ -832,7 +848,10 @@ template<bool bHessian, size_t cCompilerScores> class PartitionOneDimensionalBoo
832848 auto * const aBins =
833849 pBoosterShell->GetBoostingMainBins ()
834850 ->Specialize <FloatMain, UIntMain, true , true , bHessian, GetArrayScores (cCompilerScores)>();
835- auto * const pBinsEnd = IndexBin (aBins, cBytesPerBin * cBins);
851+ auto * pBinsEnd = IndexBin (aBins, cBytesPerBin * cBins);
852+
853+ SumAllBins<bHessian, cCompilerScores>(
854+ pBoosterShell, pBinsEnd, cSamplesTotal, weightTotal, pRootTreeNode->GetBin ());
836855
837856 const Bin<FloatMain, UIntMain, true , true , bHessian, GetArrayScores (cCompilerScores)>** const apBins =
838857 reinterpret_cast <const Bin<FloatMain, UIntMain, true , true , bHessian, GetArrayScores (cCompilerScores)>**>(
@@ -844,7 +863,6 @@ template<bool bHessian, size_t cCompilerScores> class PartitionOneDimensionalBoo
844863 const Bin<FloatMain, UIntMain, true , true , bHessian, GetArrayScores (cCompilerScores)>* pMissingBin = nullptr ;
845864 bool bMissingIsolated = false ;
846865
847- size_t cBinsAdjusted = cBins;
848866 const TreeNode<bHessian, GetArrayScores (cCompilerScores)>* pMissingValueTreeNode = nullptr ;
849867 if (TermBoostFlags_MissingLow & flags) {
850868 if (bMissing) {
@@ -861,32 +879,25 @@ template<bool bHessian, size_t cCompilerScores> class PartitionOneDimensionalBoo
861879 // Skip the missing bin in the pointer to pointer mapping since it will not be part of the continuous
862880 // region.
863881 pBin = IndexBin (pBin, cBytesPerBin);
864- --cBinsAdjusted;
865882 }
866883 }
867884
868- const Bin<FloatMain, UIntMain, true , true , bHessian, GetArrayScores (cCompilerScores)>** ppBinsEnd =
869- apBins + cBinsAdjusted;
870-
871885 do {
872886 *ppBin = pBin;
873887 pBin = IndexBin (pBin, cBytesPerBin);
874888 ++ppBin;
875- } while (ppBinsEnd != ppBin );
889+ } while (pBinsEnd != pBin );
876890
877891 if (bNominal) {
878892 std::sort (apBins,
879- ppBinsEnd ,
893+ ppBin ,
880894 CompareBin<bHessian, cCompilerScores>(
881895 !(TermBoostFlags_DisableNewtonUpdate & flags), categoricalSmoothing));
882896 }
883897
884898 pRootTreeNode->BEFORE_SetBinFirst (apBins);
885- pRootTreeNode->BEFORE_SetBinLast (ppBinsEnd - 1 );
886- ASSERT_BIN_OK (cBytesPerBin, *(ppBinsEnd - 1 ), pBoosterShell->GetDebugMainBinsEnd ());
887-
888- SumAllBins<bHessian, cCompilerScores>(
889- pBoosterShell, pBinsEnd, cSamplesTotal, weightTotal, pRootTreeNode->GetBin ());
899+ pRootTreeNode->BEFORE_SetBinLast (ppBin - 1 );
900+ ASSERT_BIN_OK (cBytesPerBin, *(ppBin - 1 ), pBoosterShell->GetDebugMainBinsEnd ());
890901
891902 EBM_ASSERT (!IsOverflowTreeNodeSize (bHessian, cScores));
892903 const size_t cBytesPerTreeNode = GetTreeNodeSize (bHessian, cScores);
@@ -1040,21 +1051,19 @@ template<bool bHessian, size_t cCompilerScores> class PartitionOneDimensionalBoo
10401051 *pTotalGain = static_cast <double >(totalGain);
10411052
10421053 size_t cSlices = cSplitsMax - cSplitsRemaining + 1 ;
1043- bool bExtraMissingCut = false ;
10441054 if (nullptr != pMissingValueTreeNode) {
10451055 EBM_ASSERT (nullptr == pMissingBin);
10461056 ++cSlices;
10471057 } else {
10481058 if (nullptr != pMissingBin && !bMissingIsolated) {
1049- bExtraMissingCut = true ;
10501059 ++cSlices;
10511060 }
10521061 }
10531062 if (bNominal) {
10541063 cSlices = cBins;
10551064 }
10561065 const ErrorEbm error = Flatten<bHessian>(pBoosterShell,
1057- bExtraMissingCut ,
1066+ bMissing ,
10581067 bNominal,
10591068 flags,
10601069 regAlpha,
0 commit comments