Skip to content

Commit ce5ec80

Browse files
committed
add SafeSum C function
1 parent f63729d commit ce5ec80

7 files changed

Lines changed: 203 additions & 3 deletions

File tree

python/interpret-core/interpret/utils/_native.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import sys
1111
import math
1212
from contextlib import AbstractContextManager
13+
from math import prod
1314

1415
import numpy as np
1516

@@ -239,6 +240,25 @@ def clean_float(self, val):
239240
)
240241
return val_array[0]
241242

243+
def safe_sum(self, in_tensor, out_tensor, axis):
    """Sum ``in_tensor`` along ``axis`` into ``out_tensor`` via the native SafeSum call.

    The input shape is collapsed into three counts that are handed to the
    native function: the product of the dimensions before ``axis``, the
    length of ``axis`` itself, and the product of the dimensions after it.

    Args:
        in_tensor: tensor to reduce; only its ``shape`` is read here before
            it is passed to ``Native._make_pointer`` as ``np.float64``.
        out_tensor: tensor that receives the result. Its total number of
            elements must equal the number of elements of ``in_tensor``
            with ``axis`` removed.
        axis: index of the dimension to sum over. Negative values count
            from the last dimension.

    Raises:
        IndexError: if ``axis`` is outside the dimensions of ``in_tensor``.
        Exception: if ``out_tensor`` does not have the reduced element
            count, or if the native call returns a non-zero error code.
    """
    shape = in_tensor.shape
    n_dims = len(shape)

    if not -n_dims <= axis < n_dims:
        msg = f"axis {axis} is out of bounds for a tensor with {n_dims} dimensions."
        raise IndexError(msg)

    # Normalize before slicing: with a negative axis, shape[axis + 1 :] would
    # not select the trailing dimensions (for axis == -1 it is shape[0:], the
    # whole shape), which would make n_close wrong.
    i_axis = axis + n_dims if axis < 0 else axis

    n_distant = prod(shape[:i_axis])
    n_close = prod(shape[i_axis + 1 :])

    if prod(out_tensor.shape) != n_distant * n_close:
        msg = f"in {in_tensor.shape} and out {out_tensor.shape} tensors must have a reducible shape along axis {axis}."
        raise Exception(msg)

    return_code = self._unsafe.SafeSum(
        n_distant,
        shape[i_axis],
        n_close,
        Native._make_pointer(in_tensor, np.float64, None, False),
        Native._make_pointer(out_tensor, np.float64, None, False),
    )
    if return_code:  # pragma: no cover
        raise Native._get_native_exception(return_code, "SafeSum")
261+
242262
def flat_mean(self, vals, weights=None):
243263
if weights is not None:
244264
if vals.shape != weights.shape:
@@ -1034,6 +1054,20 @@ def _initialize(self, is_debug):
10341054
]
10351055
self._unsafe.CleanFloats.restype = None
10361056

1057+
self._unsafe.SafeSum.argtypes = [
1058+
# int64_t countDistant
1059+
ct.c_int64,
1060+
# int64_t countAxis
1061+
ct.c_int64,
1062+
# int64_t countClose
1063+
ct.c_int64,
1064+
# double * in
1065+
ct.c_void_p,
1066+
# double * out
1067+
ct.c_void_p,
1068+
]
1069+
self._unsafe.SafeSum.restype = ct.c_int32
1070+
10371071
self._unsafe.SafeMean.argtypes = [
10381072
# int64_t countBags
10391073
ct.c_int64,

shared/libebm/inc/libebm.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -314,6 +314,9 @@ EBM_API_INCLUDE void EBM_CALLING_CONVENTION SetTraceLevel(TraceEbm traceLevel);
314314
EBM_API_INCLUDE const char* EBM_CALLING_CONVENTION GetTraceLevelString(TraceEbm traceLevel);
315315

316316
EBM_API_INCLUDE void EBM_CALLING_CONVENTION CleanFloats(IntEbm count, double* valsInOut);
317+
318+
EBM_API_INCLUDE ErrorEbm EBM_CALLING_CONVENTION SafeSum(
319+
IntEbm countDistant, IntEbm countAxis, IntEbm countClose, const double* in, double* out);
317320
EBM_API_INCLUDE ErrorEbm EBM_CALLING_CONVENTION SafeMean(
318321
IntEbm countBags, IntEbm countTensorBins, const double* vals, const double* weights, double* tensorOut);
319322
EBM_API_INCLUDE ErrorEbm EBM_CALLING_CONVENTION SafeStandardDeviation(

shared/libebm/interpretable_numerics.cpp

Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1480,6 +1480,123 @@ static double Mean(const size_t cSamples,
14801480
return mean;
14811481
}
14821482

1483+
// we don't care if an extra log message is outputted due to the non-atomic nature of the decrement to this value
1484+
static int g_cLogEnterSafeSumCount = 25;
1485+
static int g_cLogExitSafeSumCount = 25;
1486+
1487+
EBM_API_BODY ErrorEbm EBM_CALLING_CONVENTION SafeSum(
1488+
IntEbm countDistant, IntEbm countAxis, IntEbm countClose, const double* in, double* out) {
1489+
1490+
LOG_COUNTED_N(&g_cLogEnterSafeSumCount,
1491+
Trace_Info,
1492+
Trace_Verbose,
1493+
"Entered SafeSum: "
1494+
"countDistant=%" IntEbmPrintf ", "
1495+
"countAxis=%" IntEbmPrintf ", "
1496+
"countClose=%" IntEbmPrintf ", "
1497+
"in=%p, "
1498+
"out=%p",
1499+
countDistant,
1500+
countAxis,
1501+
countClose,
1502+
static_cast<const void*>(in),
1503+
static_cast<const void*>(out));
1504+
1505+
if(nullptr == in) {
1506+
LOG_0(Trace_Error, "ERROR SafeSum nullptr == in");
1507+
return Error_IllegalParamVal;
1508+
}
1509+
1510+
if(nullptr == out) {
1511+
LOG_0(Trace_Error, "ERROR SafeSum nullptr == out");
1512+
return Error_IllegalParamVal;
1513+
}
1514+
1515+
if(countDistant <= IntEbm{0}) {
1516+
if(countDistant < IntEbm{0}) {
1517+
LOG_0(Trace_Error, "ERROR SafeSum countDistant < IntEbm{0}");
1518+
return Error_IllegalParamVal;
1519+
}
1520+
return Error_None;
1521+
}
1522+
if(IsConvertError<size_t>(countDistant)) {
1523+
LOG_0(Trace_Error, "ERROR SafeSum IsConvertError<size_t>(countDistant)");
1524+
return Error_IllegalParamVal;
1525+
}
1526+
const size_t cDistant = static_cast<size_t>(countDistant);
1527+
1528+
if(countClose <= IntEbm{0}) {
1529+
if(countClose < IntEbm{0}) {
1530+
LOG_0(Trace_Error, "ERROR SafeSum countClose < IntEbm{0}");
1531+
return Error_IllegalParamVal;
1532+
}
1533+
return Error_None;
1534+
}
1535+
if(IsConvertError<size_t>(countClose)) {
1536+
LOG_0(Trace_Error, "ERROR SafeSum IsConvertError<size_t>(countClose)");
1537+
return Error_IllegalParamVal;
1538+
}
1539+
const size_t cClose = static_cast<size_t>(countClose);
1540+
1541+
if(IsMultiplyError(sizeof(double), cClose)) {
1542+
LOG_0(Trace_Error, "ERROR SafeSum IsMultiplyError(sizeof(double), cClose)");
1543+
return Error_IllegalParamVal;
1544+
}
1545+
const size_t cCloseBytes = sizeof(double) * cClose;
1546+
1547+
if(IsMultiplyError(cCloseBytes, cDistant)) {
1548+
LOG_0(Trace_Error, "ERROR SafeSum IsMultiplyError(cCloseBytes, cDistant)");
1549+
return Error_IllegalParamVal;
1550+
}
1551+
const size_t cNonAxisBytes = cCloseBytes * cDistant;
1552+
1553+
if(countAxis <= IntEbm{1}) {
1554+
if(IntEbm{1} == countAxis) {
1555+
memcpy(out, in, cNonAxisBytes);
1556+
} else if(countAxis < IntEbm{0}) {
1557+
LOG_0(Trace_Error, "ERROR SafeSum countAxis < IntEbm{0}");
1558+
return Error_IllegalParamVal;
1559+
}
1560+
return Error_None;
1561+
}
1562+
if(IsConvertError<size_t>(countAxis)) {
1563+
LOG_0(Trace_Error, "ERROR SafeSum IsConvertError<size_t>(countAxis)");
1564+
return Error_IllegalParamVal;
1565+
}
1566+
const size_t cAxis = static_cast<size_t>(countAxis);
1567+
1568+
if(IsMultiplyError(cNonAxisBytes, cAxis)) {
1569+
LOG_0(Trace_Error, "ERROR SafeSum IsMultiplyError(cNonAxisBytes, cAxis)");
1570+
return Error_IllegalParamVal;
1571+
}
1572+
1573+
const double* const pOutEnd = IndexByte(out, cNonAxisBytes);
1574+
const size_t cAxisBytes = cCloseBytes * cAxis;
1575+
const size_t cNext = sizeof(*in) - cAxisBytes;
1576+
const size_t cNextGrouping = cAxisBytes - cCloseBytes;
1577+
size_t i = 0;
1578+
do {
1579+
const double* const pCloseEnd = IndexByte(out, cCloseBytes);
1580+
do {
1581+
const size_t iEnd = i + cAxisBytes;
1582+
double sum = *IndexByte(in, i);
1583+
i += cCloseBytes;
1584+
do {
1585+
sum += *IndexByte(in, i);
1586+
i += cCloseBytes;
1587+
} while(iEnd != i);
1588+
i += cNext;
1589+
*out = sum;
1590+
++out;
1591+
} while(pCloseEnd != out);
1592+
i += cNextGrouping;
1593+
} while(pOutEnd != out);
1594+
1595+
LOG_COUNTED_0(&g_cLogExitSafeSumCount, Trace_Info, Trace_Verbose, "Exited SafeSum");
1596+
1597+
return Error_None;
1598+
}
1599+
14831600
// we don't care if an extra log message is outputted due to the non-atomic nature of the decrement to this value
14841601
static int g_cLogEnterSafeMeanCount = 25;
14851602
static int g_cLogExitSafeMeanCount = 25;

shared/libebm/libebm_exports.def

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ EXPORTS
55
SetTraceLevel
66
GetTraceLevelString
77
CleanFloats
8+
SafeSum
89
SafeMean
910
SafeStandardDeviation
1011
MeasureRNG

shared/libebm/libebm_exports.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
SetTraceLevel;
55
GetTraceLevelString;
66
CleanFloats;
7+
SafeSum;
78
SafeMean;
89
SafeStandardDeviation;
910
MeasureRNG;

shared/libebm/tests/SuggestGraphBoundsTest.cpp

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -302,6 +302,50 @@ TEST_CASE("SuggestGraphBounds, 2 cuts, overflow diff") {
302302
CHECK(std::numeric_limits<double>::infinity() == highGraphBound);
303303
}
304304

305+
TEST_CASE("SafeSum, 4x3x2") {
   // 4 x 3 x 2 tensor holding 1..24 in storage order. Each row below is one slice along the
   // first dimension; within a row the items are laid out as 3 pairs.
   const double in[]{
         1.0, 2.0, 3.0, 4.0, 5.0, 6.0,
         7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
         13.0, 14.0, 15.0, 16.0, 17.0, 18.0,
         19.0, 20.0, 21.0, 22.0, 23.0, 24.0};

   // summing away the middle dimension of length 3 leaves a third of the items
   constexpr size_t cOut = sizeof(in) / sizeof(in[0]) / 3;
   double out[cOut];

   const ErrorEbm error = SafeSum(4, 3, 2, in, out);
   CHECK(Error_None == error);

   // first slice
   CHECK(out[0] == 1.0 + 3.0 + 5.0);
   CHECK(out[1] == 2.0 + 4.0 + 6.0);

   // second slice
   CHECK(out[2] == 7.0 + 9.0 + 11.0);
   CHECK(out[3] == 8.0 + 10.0 + 12.0);

   // third slice
   CHECK(out[4] == 13.0 + 15.0 + 17.0);
   CHECK(out[5] == 14.0 + 16.0 + 18.0);

   // fourth slice
   CHECK(out[6] == 19.0 + 21.0 + 23.0);
   CHECK(out[7] == 20.0 + 22.0 + 24.0);
}
348+
305349
TEST_CASE("SafeMean, 4 values") {
306350
double vals[]{1.0, 2.5, 10, 100};
307351
const size_t cVals = sizeof(vals) / sizeof(vals[0]);

shared/libebm/tests/boosting_unusual_inputs.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2363,7 +2363,7 @@ static double RandomizedTesting(const AccelerationFlags acceleration) {
23632363
}
23642364
} else {
23652365
double modelSum = 0.0;
2366-
for(IntEbm iTerm = 0; iTerm < static_cast<IntEbm>(terms.size()); ++iTerm) {
2366+
for(size_t iTerm = 0; iTerm < terms.size(); ++iTerm) {
23672367
const auto term = terms[iTerm];
23682368
size_t cScores = 3 <= classesCount ? static_cast<size_t>(classesCount) : size_t{1};
23692369
for(size_t iDim = 0; iDim < term.size(); ++iDim) {
@@ -2372,12 +2372,12 @@ static double RandomizedTesting(const AccelerationFlags acceleration) {
23722372
}
23732373

23742374
std::vector<double> model(cScores);
2375-
test.GetBestTermScoresRaw(iTerm, &model[0]);
2375+
test.GetBestTermScoresRaw(static_cast<IntEbm>(iTerm), &model[0]);
23762376
for(size_t iScore = 0; iScore < cScores; ++iScore) {
23772377
modelSum += model[iScore];
23782378
}
23792379

2380-
test.GetCurrentTermScoresRaw(iTerm, &model[0]);
2380+
test.GetCurrentTermScoresRaw(static_cast<IntEbm>(iTerm), &model[0]);
23812381
for(size_t iScore = 0; iScore < cScores; ++iScore) {
23822382
modelSum += model[iScore];
23832383
}

0 commit comments

Comments
 (0)