Skip to content

Commit 7cea7a3

Browse files
committed
create an explicit leaf cut whenever there is a missing value
1 parent f246e66 commit 7cea7a3

File tree

1 file changed

+54
-4
lines changed

1 file changed

+54
-4
lines changed

shared/libebm/PartitionOneDimensionalBoosting.cpp

+54-4
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,7 @@ WARNING_DISABLE_UNINITIALIZED_LOCAL_POINTER
108108
// Do not inline this. Keeping it out of line means fewer versions get generated, and the more heavily templated functions can all call those same versions.
109109
template<bool bHessian>
110110
static ErrorEbm Flatten(BoosterShell* const pBoosterShell,
111+
bool bExtraMissingCut,
111112
const bool bNominal,
112113
const TermBoostFlags flags,
113114
const FloatCalc regAlpha,
@@ -131,6 +132,7 @@ static ErrorEbm Flatten(BoosterShell* const pBoosterShell,
131132
EBM_ASSERT(2 <= cBins);
132133
EBM_ASSERT(cSlices <= cBins);
133134
EBM_ASSERT(!bNominal || cSlices == cBins);
135+
EBM_ASSERT(!bExtraMissingCut || !bNominal); // for Nominal we cut everywhere
134136

135137
ErrorEbm error;
136138

@@ -176,7 +178,7 @@ static ErrorEbm Flatten(BoosterShell* const pBoosterShell,
176178
} else {
177179
pUpdateScore = aUpdateScore;
178180

179-
if(nullptr != pMissingValueTreeNode) {
181+
if(nullptr != pMissingValueTreeNode || bExtraMissingCut) {
180182
// always put a split on the missing bin
181183
*pSplit = 1;
182184
++pSplit;
@@ -237,6 +239,18 @@ static ErrorEbm Flatten(BoosterShell* const pBoosterShell,
237239
}
238240
}
239241

242+
if(bExtraMissingCut) {
243+
EBM_ASSERT(!bNominal); // for Nominal we cut everywhere
244+
if(TermBoostFlags_MissingLow & flags) {
245+
if(nullptr == pMissingBin) {
246+
pMissingBin = pTreeNode->GetBin();
247+
}
248+
} else {
249+
EBM_ASSERT(TermBoostFlags_MissingHigh & flags);
250+
pMissingBin = pTreeNode->GetBin();
251+
}
252+
}
253+
240254
EBM_ASSERT(apBins <= ppBinLast);
241255
EBM_ASSERT(ppBinLast < apBins + (cBins - (nullptr != pMissingValueTreeNode ? size_t{1} : size_t{0})));
242256

@@ -365,6 +379,8 @@ static int FindBestSplitGain(RandomDeterministic* const pRng,
365379
const FloatCalc regLambda,
366380
const FloatCalc deltaStepMax,
367381
const MonotoneDirection monotoneDirection,
382+
const Bin<FloatMain, UIntMain, true, true, bHessian, GetArrayScores(cCompilerScores)>* const pMissingBin,
383+
bool* pbMissingIsolated,
368384
const TreeNode<bHessian, GetArrayScores(cCompilerScores)>** const ppMissingValueTreeNode) {
369385

370386
LOG_N(Trace_Verbose,
@@ -401,6 +417,9 @@ static int FindBestSplitGain(RandomDeterministic* const pRng,
401417
if(ppBinCur == ppBinLast) {
402418
// There is just one bin and therefore no splits
403419
pTreeNode->AFTER_RejectSplit();
420+
if(pMissingBin == *ppBinCur) {
421+
*pbMissingIsolated = true;
422+
}
404423
return 1;
405424
}
406425

@@ -822,10 +841,16 @@ template<bool bHessian, size_t cCompilerScores> class PartitionOneDimensionalBoo
822841
const Bin<FloatMain, UIntMain, true, true, bHessian, GetArrayScores(cCompilerScores)>** ppBin = apBins;
823842
const Bin<FloatMain, UIntMain, true, true, bHessian, GetArrayScores(cCompilerScores)>* pBin = aBins;
824843

844+
const Bin<FloatMain, UIntMain, true, true, bHessian, GetArrayScores(cCompilerScores)>* pMissingBin = nullptr;
845+
bool bMissingIsolated = false;
846+
825847
size_t cBinsAdjusted = cBins;
826848
const TreeNode<bHessian, GetArrayScores(cCompilerScores)>* pMissingValueTreeNode = nullptr;
827849
if(TermBoostFlags_MissingLow & flags) {
828850
if(bMissing) {
851+
if(!bNominal) {
852+
pMissingBin = pBin;
853+
}
829854
*ppBin = pBin;
830855
pBin = IndexBin(pBin, cBytesPerBin);
831856
++ppBin;
@@ -879,6 +904,8 @@ template<bool bHessian, size_t cCompilerScores> class PartitionOneDimensionalBoo
879904
regLambda,
880905
deltaStepMax,
881906
monotoneDirection,
907+
pMissingBin,
908+
&bMissingIsolated,
882909
&pMissingValueTreeNode);
883910
size_t cSplitsRemaining = cSplitsMax;
884911
FloatCalc totalGain = 0;
@@ -952,6 +979,8 @@ template<bool bHessian, size_t cCompilerScores> class PartitionOneDimensionalBoo
952979
regLambda,
953980
deltaStepMax,
954981
monotoneDirection,
982+
pMissingBin,
983+
&bMissingIsolated,
955984
&pMissingValueTreeNode);
956985
// if FindBestSplitGain returned -1 to indicate an
957986
// overflow ignore it here. We successfully made a root node split, so we might as well continue
@@ -976,6 +1005,8 @@ template<bool bHessian, size_t cCompilerScores> class PartitionOneDimensionalBoo
9761005
regLambda,
9771006
deltaStepMax,
9781007
monotoneDirection,
1008+
pMissingBin,
1009+
&bMissingIsolated,
9791010
&pMissingValueTreeNode);
9801011
// if FindBestSplitGain returned -1 to indicate an
9811012
// overflow ignore it here. We successfully made a root node split, so we might as well continue
@@ -1007,9 +1038,23 @@ template<bool bHessian, size_t cCompilerScores> class PartitionOneDimensionalBoo
10071038
}
10081039
}
10091040
*pTotalGain = static_cast<double>(totalGain);
1010-
size_t cSlices =
1011-
bNominal ? cBins : cSplitsMax - cSplitsRemaining + 1 + (nullptr != pMissingValueTreeNode ? 1 : 0);
1012-
return Flatten<bHessian>(pBoosterShell,
1041+
1042+
size_t cSlices = cSplitsMax - cSplitsRemaining + 1;
1043+
bool bExtraMissingCut = false;
1044+
if(nullptr != pMissingValueTreeNode) {
1045+
EBM_ASSERT(nullptr == pMissingBin);
1046+
++cSlices;
1047+
} else {
1048+
if(nullptr != pMissingBin && !bMissingIsolated) {
1049+
bExtraMissingCut = true;
1050+
++cSlices;
1051+
}
1052+
}
1053+
if(bNominal) {
1054+
cSlices = cBins;
1055+
}
1056+
const ErrorEbm error = Flatten<bHessian>(pBoosterShell,
1057+
bExtraMissingCut,
10131058
bNominal,
10141059
flags,
10151060
regAlpha,
@@ -1024,6 +1069,11 @@ template<bool bHessian, size_t cCompilerScores> class PartitionOneDimensionalBoo
10241069
cBins
10251070
#endif // NDEBUG
10261071
);
1072+
1073+
EBM_ASSERT(!bMissing || 2 <= pBoosterShell->GetInnerTermUpdate()->GetCountSlices(iDimension));
1074+
EBM_ASSERT(!bMissing || *pBoosterShell->GetInnerTermUpdate()->GetSplitPointer(iDimension) == 1);
1075+
1076+
return error;
10271077
}
10281078
};
10291079

0 commit comments

Comments
 (0)