Skip to content

Commit 42875bc

Browse files
committed
Move the handling of the cuts for the missing bin to the beginning of the flattening process
1 parent 7cea7a3 commit 42875bc

File tree

1 file changed

+74
-65
lines changed

1 file changed

+74
-65
lines changed

shared/libebm/PartitionOneDimensionalBoosting.cpp

+74-65
Original file line number | Diff line number | Diff line change
@@ -108,7 +108,7 @@ WARNING_DISABLE_UNINITIALIZED_LOCAL_POINTER
108108
// do not inline this. Not inlining it makes fewer versions that can be called from the more templated functions
109109
template<bool bHessian>
110110
static ErrorEbm Flatten(BoosterShell* const pBoosterShell,
111-
bool bExtraMissingCut,
111+
const bool bMissing,
112112
const bool bNominal,
113113
const TermBoostFlags flags,
114114
const FloatCalc regAlpha,
@@ -132,7 +132,6 @@ static ErrorEbm Flatten(BoosterShell* const pBoosterShell,
132132
EBM_ASSERT(2 <= cBins);
133133
EBM_ASSERT(cSlices <= cBins);
134134
EBM_ASSERT(!bNominal || cSlices == cBins);
135-
EBM_ASSERT(!bExtraMissingCut || !bNominal); // for Nominal we cut everywhere
136135

137136
ErrorEbm error;
138137

@@ -178,7 +177,7 @@ static ErrorEbm Flatten(BoosterShell* const pBoosterShell,
178177
} else {
179178
pUpdateScore = aUpdateScore;
180179

181-
if(nullptr != pMissingValueTreeNode || bExtraMissingCut) {
180+
if(bMissing) {
182181
// always put a split on the missing bin
183182
*pSplit = 1;
184183
++pSplit;
@@ -239,18 +238,6 @@ static ErrorEbm Flatten(BoosterShell* const pBoosterShell,
239238
}
240239
}
241240

242-
if(bExtraMissingCut) {
243-
EBM_ASSERT(!bNominal); // for Nominal we cut everywhere
244-
if(TermBoostFlags_MissingLow & flags) {
245-
if(nullptr == pMissingBin) {
246-
pMissingBin = pTreeNode->GetBin();
247-
}
248-
} else {
249-
EBM_ASSERT(TermBoostFlags_MissingHigh & flags);
250-
pMissingBin = pTreeNode->GetBin();
251-
}
252-
}
253-
254241
EBM_ASSERT(apBins <= ppBinLast);
255242
EBM_ASSERT(ppBinLast < apBins + (cBins - (nullptr != pMissingValueTreeNode ? size_t{1} : size_t{0})));
256243

@@ -273,41 +260,56 @@ static ErrorEbm Flatten(BoosterShell* const pBoosterShell,
273260

274261
iEdge = ppBinLast - apBins + 1 + (nullptr != pMissingValueTreeNode ? 1 : 0);
275262

276-
while(true) {
277-
iScore = 0;
278-
do {
279-
FloatCalc updateScore;
280-
if(bUpdateWithHessian) {
281-
updateScore = -CalcNegUpdate<true>(static_cast<FloatCalc>(aGradientPair[iScore].m_sumGradients),
282-
static_cast<FloatCalc>(aGradientPair[iScore].GetHess()),
283-
regAlpha,
284-
regLambda,
285-
deltaStepMax);
286-
} else {
287-
updateScore = -CalcNegUpdate<true>(static_cast<FloatCalc>(aGradientPair[iScore].m_sumGradients),
288-
static_cast<FloatCalc>(pTreeNode->GetBin()->GetWeight()),
289-
regAlpha,
290-
regLambda,
291-
deltaStepMax);
263+
while(true) { // not a real loop
264+
if(bMissing) {
265+
if(TermBoostFlags_MissingLow & flags) {
266+
if(nullptr == pMissingBin) {
267+
pMissingBin = pTreeNode->GetBin();
268+
}
269+
if(1 == iEdge) {
270+
break;
271+
}
292272
}
273+
}
293274

294-
*pUpdateScore = static_cast<FloatScore>(updateScore);
295-
++pUpdateScore;
275+
while(true) {
276+
iScore = 0;
277+
do {
278+
FloatCalc updateScore;
279+
if(bUpdateWithHessian) {
280+
updateScore = -CalcNegUpdate<true>(static_cast<FloatCalc>(aGradientPair[iScore].m_sumGradients),
281+
static_cast<FloatCalc>(aGradientPair[iScore].GetHess()),
282+
regAlpha,
283+
regLambda,
284+
deltaStepMax);
285+
} else {
286+
updateScore = -CalcNegUpdate<true>(static_cast<FloatCalc>(aGradientPair[iScore].m_sumGradients),
287+
static_cast<FloatCalc>(pTreeNode->GetBin()->GetWeight()),
288+
regAlpha,
289+
regLambda,
290+
deltaStepMax);
291+
}
296292

297-
++iScore;
298-
} while(cScores != iScore);
299-
if(nullptr == ppBinCur) {
300-
break;
301-
}
302-
EBM_ASSERT(bNominal);
303-
++ppBinCur;
304-
if(ppBinLast < ppBinCur) {
305-
break;
293+
*pUpdateScore = static_cast<FloatScore>(updateScore);
294+
++pUpdateScore;
295+
296+
++iScore;
297+
} while(cScores != iScore);
298+
if(nullptr == ppBinCur) {
299+
break;
300+
}
301+
EBM_ASSERT(bNominal);
302+
++ppBinCur;
303+
if(ppBinLast < ppBinCur) {
304+
break;
305+
}
306+
determine_bin:;
307+
const auto* const pBinCur = *ppBinCur;
308+
const size_t iBin = CountBins(pBinCur, aBins, cBytesPerBin);
309+
pUpdateScore = aUpdateScore + iBin * cScores;
306310
}
307-
determine_bin:;
308-
const auto* const pBinCur = *ppBinCur;
309-
const size_t iBin = CountBins(pBinCur, aBins, cBytesPerBin);
310-
pUpdateScore = aUpdateScore + iBin * cScores;
311+
312+
break;
311313
}
312314

313315
pTreeNode = pParent;
@@ -345,9 +347,23 @@ static ErrorEbm Flatten(BoosterShell* const pBoosterShell,
345347
if(!pTreeNode->DECONSTRUCT_IsRightChildTraversal()) {
346348
// we checked earlier that countBins could be converted to a UIntSplit
347349
if(nullptr == ppBinCur) {
348-
EBM_ASSERT(!IsConvertError<UIntSplit>(iEdge));
349-
*pSplit = static_cast<UIntSplit>(iEdge);
350-
++pSplit;
350+
EBM_ASSERT(!bNominal);
351+
352+
while(true) { // not a real loop
353+
if(bMissing) {
354+
if(TermBoostFlags_MissingLow & flags) {
355+
if(1 == iEdge) {
356+
break;
357+
}
358+
}
359+
}
360+
361+
EBM_ASSERT(!IsConvertError<UIntSplit>(iEdge));
362+
*pSplit = static_cast<UIntSplit>(iEdge);
363+
++pSplit;
364+
365+
break;
366+
}
351367
}
352368
pParent = pTreeNode;
353369
pTreeNode = pTreeNode->DECONSTRUCT_TraverseRightAndMark(cBytesPerTreeNode);
@@ -832,7 +848,10 @@ template<bool bHessian, size_t cCompilerScores> class PartitionOneDimensionalBoo
832848
auto* const aBins =
833849
pBoosterShell->GetBoostingMainBins()
834850
->Specialize<FloatMain, UIntMain, true, true, bHessian, GetArrayScores(cCompilerScores)>();
835-
auto* const pBinsEnd = IndexBin(aBins, cBytesPerBin * cBins);
851+
auto* pBinsEnd = IndexBin(aBins, cBytesPerBin * cBins);
852+
853+
SumAllBins<bHessian, cCompilerScores>(
854+
pBoosterShell, pBinsEnd, cSamplesTotal, weightTotal, pRootTreeNode->GetBin());
836855

837856
const Bin<FloatMain, UIntMain, true, true, bHessian, GetArrayScores(cCompilerScores)>** const apBins =
838857
reinterpret_cast<const Bin<FloatMain, UIntMain, true, true, bHessian, GetArrayScores(cCompilerScores)>**>(
@@ -844,7 +863,6 @@ template<bool bHessian, size_t cCompilerScores> class PartitionOneDimensionalBoo
844863
const Bin<FloatMain, UIntMain, true, true, bHessian, GetArrayScores(cCompilerScores)>* pMissingBin = nullptr;
845864
bool bMissingIsolated = false;
846865

847-
size_t cBinsAdjusted = cBins;
848866
const TreeNode<bHessian, GetArrayScores(cCompilerScores)>* pMissingValueTreeNode = nullptr;
849867
if(TermBoostFlags_MissingLow & flags) {
850868
if(bMissing) {
@@ -861,32 +879,25 @@ template<bool bHessian, size_t cCompilerScores> class PartitionOneDimensionalBoo
861879
// Skip the missing bin in the pointer to pointer mapping since it will not be part of the continuous
862880
// region.
863881
pBin = IndexBin(pBin, cBytesPerBin);
864-
--cBinsAdjusted;
865882
}
866883
}
867884

868-
const Bin<FloatMain, UIntMain, true, true, bHessian, GetArrayScores(cCompilerScores)>** ppBinsEnd =
869-
apBins + cBinsAdjusted;
870-
871885
do {
872886
*ppBin = pBin;
873887
pBin = IndexBin(pBin, cBytesPerBin);
874888
++ppBin;
875-
} while(ppBinsEnd != ppBin);
889+
} while(pBinsEnd != pBin);
876890

877891
if(bNominal) {
878892
std::sort(apBins,
879-
ppBinsEnd,
893+
ppBin,
880894
CompareBin<bHessian, cCompilerScores>(
881895
!(TermBoostFlags_DisableNewtonUpdate & flags), categoricalSmoothing));
882896
}
883897

884898
pRootTreeNode->BEFORE_SetBinFirst(apBins);
885-
pRootTreeNode->BEFORE_SetBinLast(ppBinsEnd - 1);
886-
ASSERT_BIN_OK(cBytesPerBin, *(ppBinsEnd - 1), pBoosterShell->GetDebugMainBinsEnd());
887-
888-
SumAllBins<bHessian, cCompilerScores>(
889-
pBoosterShell, pBinsEnd, cSamplesTotal, weightTotal, pRootTreeNode->GetBin());
899+
pRootTreeNode->BEFORE_SetBinLast(ppBin - 1);
900+
ASSERT_BIN_OK(cBytesPerBin, *(ppBin - 1), pBoosterShell->GetDebugMainBinsEnd());
890901

891902
EBM_ASSERT(!IsOverflowTreeNodeSize(bHessian, cScores));
892903
const size_t cBytesPerTreeNode = GetTreeNodeSize(bHessian, cScores);
@@ -1040,21 +1051,19 @@ template<bool bHessian, size_t cCompilerScores> class PartitionOneDimensionalBoo
10401051
*pTotalGain = static_cast<double>(totalGain);
10411052

10421053
size_t cSlices = cSplitsMax - cSplitsRemaining + 1;
1043-
bool bExtraMissingCut = false;
10441054
if(nullptr != pMissingValueTreeNode) {
10451055
EBM_ASSERT(nullptr == pMissingBin);
10461056
++cSlices;
10471057
} else {
10481058
if(nullptr != pMissingBin && !bMissingIsolated) {
1049-
bExtraMissingCut = true;
10501059
++cSlices;
10511060
}
10521061
}
10531062
if(bNominal) {
10541063
cSlices = cBins;
10551064
}
10561065
const ErrorEbm error = Flatten<bHessian>(pBoosterShell,
1057-
bExtraMissingCut,
1066+
bMissing,
10581067
bNominal,
10591068
flags,
10601069
regAlpha,

0 commit comments

Comments
 (0)