Skip to content

Commit bbaaacf

Browse files
committed
create internal versions of exp and log to ensure identical cross platform results
1 parent 86a8a9e commit bbaaacf

7 files changed

Lines changed: 117 additions & 1 deletion

File tree

python/interpret-core/interpret/utils/_native.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -265,7 +265,7 @@ def flat_mean(self, vals, weights=None):
265265
msg = "vals and weights must have the same shape to call flat_mean."
266266
raise Exception(msg)
267267

268-
n_tensor_bins = math.prod(vals.shape)
268+
n_tensor_bins = prod(vals.shape)
269269

270270
mean_result = ct.c_double(np.nan)
271271

@@ -341,6 +341,18 @@ def safe_stddev(self, tensor, weights=None):
341341

342342
return stddev_result
343343

344+
def safe_exp(self, vals):
345+
self._unsafe.SafeExp(
346+
prod(vals.shape),
347+
Native._make_pointer(vals, np.float64, None),
348+
)
349+
350+
def safe_log(self, vals):
351+
self._unsafe.SafeLog(
352+
prod(vals.shape),
353+
Native._make_pointer(vals, np.float64, None),
354+
)
355+
344356
def create_rng(self, random_state):
345357
if random_state is None:
346358
return None # non-deterministic
@@ -1096,6 +1108,22 @@ def _initialize(self, is_debug):
10961108
]
10971109
self._unsafe.SafeStandardDeviation.restype = ct.c_int32
10981110

1111+
self._unsafe.SafeExp.argtypes = [
1112+
# int64_t count
1113+
ct.c_int64,
1114+
# double * inout
1115+
ct.c_void_p,
1116+
]
1117+
self._unsafe.SafeExp.restype = None
1118+
1119+
self._unsafe.SafeLog.argtypes = [
1120+
# int64_t count
1121+
ct.c_int64,
1122+
# double * inout
1123+
ct.c_void_p,
1124+
]
1125+
self._unsafe.SafeLog.restype = None
1126+
10991127
self._unsafe.MeasureRNG.argtypes = []
11001128
self._unsafe.MeasureRNG.restype = ct.c_int64
11011129

shared/libebm/bridge/bridge.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -213,6 +213,9 @@ INTERNAL_IMPORT_EXPORT_INCLUDE ErrorEbm CreateMetric_Cpu_64(
213213
// MetricWrapper * const pMetricWrapperOut,
214214
);
215215

216+
INTERNAL_IMPORT_EXPORT_INCLUDE void Exp_Cpu_64(const size_t c, double* const a);
217+
INTERNAL_IMPORT_EXPORT_INCLUDE void Log_Cpu_64(const size_t c, double* const a);
218+
216219
INTERNAL_IMPORT_EXPORT_INCLUDE double FinishMetricC(
217220
const ObjectiveWrapper* const pObjectiveWrapper, const double metricSum);
218221
INTERNAL_IMPORT_EXPORT_INCLUDE BoolEbm CheckTargetsC(

shared/libebm/compute/cpu_ebm/cpu_64.cpp

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -519,4 +519,26 @@ INTERNAL_IMPORT_EXPORT_BODY ErrorEbm CreateMetric_Cpu_64(
519519
return Error_UnexpectedInternal;
520520
}
521521

522+
INTERNAL_IMPORT_EXPORT_BODY void Exp_Cpu_64(const size_t c, double* const a) {
523+
double* p = a;
524+
double* aEnd = a + c;
525+
while(aEnd != p) {
526+
Cpu_64_Float val = Cpu_64_Float::Load(p);
527+
val = Exp(val);
528+
val.Store(p);
529+
p += Cpu_64_Float::k_cSIMDPack;
530+
}
531+
}
532+
533+
INTERNAL_IMPORT_EXPORT_BODY void Log_Cpu_64(const size_t c, double* const a) {
534+
double* p = a;
535+
double* aEnd = a + c;
536+
while(aEnd != p) {
537+
Cpu_64_Float val = Cpu_64_Float::Load(p);
538+
val = Log(val);
539+
val.Store(p);
540+
p += Cpu_64_Float::k_cSIMDPack;
541+
}
542+
}
543+
522544
} // namespace DEFINED_ZONE_NAME

shared/libebm/inc/libebm.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -322,6 +322,9 @@ EBM_API_INCLUDE ErrorEbm EBM_CALLING_CONVENTION SafeMean(
322322
EBM_API_INCLUDE ErrorEbm EBM_CALLING_CONVENTION SafeStandardDeviation(
323323
IntEbm countBags, IntEbm countTensorBins, const double* vals, const double* weights, double* tensorOut);
324324

325+
EBM_API_INCLUDE void EBM_CALLING_CONVENTION SafeExp(IntEbm count, double* inout);
326+
EBM_API_INCLUDE void EBM_CALLING_CONVENTION SafeLog(IntEbm count, double* inout);
327+
325328
EBM_API_INCLUDE IntEbm EBM_CALLING_CONVENTION MeasureRNG(void);
326329
EBM_API_INCLUDE void EBM_CALLING_CONVENTION InitRNG(SeedEbm seed, void* rngOut);
327330
EBM_API_INCLUDE void EBM_CALLING_CONVENTION CopyRNG(void* rng, void* rngOut);

shared/libebm/interpretable_numerics.cpp

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1772,6 +1772,62 @@ EBM_API_BODY ErrorEbm EBM_CALLING_CONVENTION SafeStandardDeviation(
17721772
return Error_None;
17731773
}
17741774

1775+
EBM_API_BODY void EBM_CALLING_CONVENTION SafeExp(IntEbm count, double* inout) {
1776+
if(count <= IntEbm{0}) {
1777+
if(count < IntEbm{0}) {
1778+
LOG_0(Trace_Error, "ERROR SafeExp count < IntEbm{0}");
1779+
}
1780+
return;
1781+
}
1782+
if(IsConvertError<size_t>(count)) {
1783+
LOG_0(Trace_Error, "ERROR SafeExp IsConvertError<size_t>(count)");
1784+
return;
1785+
}
1786+
const size_t c = static_cast<size_t>(count);
1787+
1788+
if(IsMultiplyError(sizeof(*inout), c)) {
1789+
LOG_0(Trace_Error, "ERROR SafeExp IsMultiplyError(sizeof(*inout), c)");
1790+
return;
1791+
}
1792+
1793+
if(nullptr == inout) {
1794+
LOG_0(Trace_Error, "ERROR SafeExp nullptr == inout");
1795+
return;
1796+
}
1797+
1798+
Exp_Cpu_64(c, inout);
1799+
1800+
// TODO: add a SIMD version here
1801+
}
1802+
1803+
EBM_API_BODY void EBM_CALLING_CONVENTION SafeLog(IntEbm count, double* inout) {
1804+
if(count <= IntEbm{0}) {
1805+
if(count < IntEbm{0}) {
1806+
LOG_0(Trace_Error, "ERROR SafeLog count < IntEbm{0}");
1807+
}
1808+
return;
1809+
}
1810+
if(IsConvertError<size_t>(count)) {
1811+
LOG_0(Trace_Error, "ERROR SafeLog IsConvertError<size_t>(count)");
1812+
return;
1813+
}
1814+
const size_t c = static_cast<size_t>(count);
1815+
1816+
if(IsMultiplyError(sizeof(*inout), c)) {
1817+
LOG_0(Trace_Error, "ERROR SafeLog IsMultiplyError(sizeof(*inout), c)");
1818+
return;
1819+
}
1820+
1821+
if(nullptr == inout) {
1822+
LOG_0(Trace_Error, "ERROR SafeLog nullptr == inout");
1823+
return;
1824+
}
1825+
1826+
Log_Cpu_64(c, inout);
1827+
1828+
// TODO: add a SIMD version here
1829+
}
1830+
17751831
// we don't care if an extra log message is outputted due to the non-atomic nature of the decrement to this value
17761832
static int g_cLogEnterGetHistogramCutCount = 25;
17771833
static int g_cLogExitGetHistogramCutCount = 25;

shared/libebm/libebm_exports.def

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@ EXPORTS
88
SafeSum
99
SafeMean
1010
SafeStandardDeviation
11+
SafeExp
12+
SafeLog
1113
MeasureRNG
1214
InitRNG
1315
CopyRNG

shared/libebm/libebm_exports.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
SafeSum;
88
SafeMean;
99
SafeStandardDeviation;
10+
SafeExp;
11+
SafeLog;
1012
MeasureRNG;
1113
InitRNG;
1214
CopyRNG;

0 commit comments

Comments
 (0)