Skip to content

Commit 458d69c

Browse files
authored
Add target clipping and update init point for platt transform (#272)
1 parent e4c824d commit 458d69c

File tree

3 files changed

+42
-23
lines changed

3 files changed

+42
-23
lines changed

pecos/core/base.py

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2058,25 +2058,37 @@ def link_calibrator_methods(self):
20582058
[c_uint64, POINTER(c_double), POINTER(c_double), POINTER(c_double)],
20592059
)
20602060

2061-
def fit_platt_transform(self, logits, tgt_prob):
2061+
def fit_platt_transform(self, logits, targets, clip_tgt_prob=True):
20622062
"""Python to C/C++ interface for platt transform fit.
20632063
20642064
Ref: https://www.csie.ntu.edu.tw/~cjlin/papers/plattprob.pdf
20652065
20662066
Args:
20672067
logits (ndarray): 1-d array of logits with length N.
2068-
tgt_prob (ndarray): 1-d array of target probability scores within [0, 1] with length N.
2068+
targets (ndarray): 1-d array of target probability scores within [0, 1] with length N.
2069+
clip_tgt_prob (bool): whether to clip the target probability to
2070+
[1/(prior0 + 2), 1 - 1/(prior1 + 2)]
2071+
where prior1 = sum(targets), prior0 = N - prior1
20692072
Returns:
20702073
A, B: coefficients for Platt's scale.
20712074
"""
20722075
assert isinstance(logits, np.ndarray)
2073-
assert isinstance(tgt_prob, np.ndarray)
2074-
assert len(logits) == len(tgt_prob)
2075-
assert logits.dtype == tgt_prob.dtype
2076+
assert isinstance(targets, np.ndarray)
2077+
assert len(logits) == len(targets)
2078+
assert logits.dtype == targets.dtype
20762079

2077-
if tgt_prob.min() < 0 or tgt_prob.max() > 1.0:
2080+
if targets.min() < 0 or targets.max() > 1.0:
20782081
raise ValueError("Target probability out of bound!")
20792082

2083+
min_prob, max_prob = 0.0, 1.0
2084+
if clip_tgt_prob:
2085+
prior1 = np.sum(targets)
2086+
prior0 = len(targets) - prior1
2087+
min_prob = 1.0 / (prior0 + 2.0)
2088+
max_prob = (prior1 + 1.0) / (prior1 + 2.0)
2089+
2090+
tgt_prob = np.clip(targets, min_prob, max_prob)
2091+
20802092
AB = np.array([0, 0], dtype=np.float64)
20812093

20822094
if tgt_prob.dtype == np.float32:

pecos/core/utils/newton.hpp

Lines changed: 22 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -280,23 +280,29 @@ namespace pecos {
280280

281281
template <typename value_type>
282282
uint32_t fit_platt_transform(size_t num_samples, const value_type *logits, const value_type *tgt_probs, double& A, double& B) {
283-
// define the return code
284-
enum {
285-
SUCCESS=0,
286-
LINE_SEARCH_FAIL=1,
287-
MAX_ITER_REACHED=2,
288-
};
283+
// define the return code
284+
enum {
285+
SUCCESS=0,
286+
LINE_SEARCH_FAIL=1,
287+
MAX_ITER_REACHED=2,
288+
};
289289

290290
// hyper parameters
291291
int max_iter = 100; // Maximal number of iterations
292292
double min_step = 1e-10; // Minimal step taken in line search
293293
double sigma = 1e-12; // For numerically strict PD of Hessian
294-
double eps = 1e-6;
294+
double eps = 1e-5;
295+
296+
// calculate prior of B
297+
double prior1 = 0;
298+
for (size_t i = 0; i < num_samples; i++) {
299+
prior1 += tgt_probs[i];
300+
}
301+
double prior0 = double(num_samples) - prior1;
295302

296-
int iter;
297303

298304
// Initial Point and Initial Fun Value
299-
A = 0.0; B = 1.0;
305+
A = 0.0; B = log((prior0 + 1.0) / (prior1 + 1.0));
300306
double fval = 0.0;
301307

302308
for (size_t i = 0; i < num_samples; i++) {
@@ -307,17 +313,18 @@ namespace pecos {
307313
fval += (tgt_probs[i] - 1) * fApB + log(1 + exp(fApB));
308314
}
309315
}
316+
int iter;
310317
for (iter = 0; iter < max_iter; iter++) {
311318
// Update Gradient and Hessian (use H' = H + sigma I)
312319
double h11 = sigma;
313320
double h22 = sigma; // numerically ensures strict PD
314321
double h21 = 0.0;
315-
double g1 = 0.0;
316-
double g2 = 0.0;
322+
double g1 = A * sigma;
323+
double g2 = B * sigma;
317324

318325
for (size_t i = 0; i < num_samples; i++) {
319326
double fApB = logits[i] * A + B;
320-
double p = 0, q = 0;
327+
double p = 0, q = 0;
321328
if (fApB >= 0) {
322329
p = exp(-fApB) / (1.0 + exp(-fApB));
323330
q = 1.0 / (1.0 + exp(-fApB));
@@ -376,15 +383,15 @@ namespace pecos {
376383

377384
if (stepsize < min_step) {
378385
printf("WARNING: fit_platt_transform: Line search fails\n");
379-
return LINE_SEARCH_FAIL;
386+
return LINE_SEARCH_FAIL;
380387
}
381388
}
382389

383390
if (iter >= max_iter) {
384391
printf("WARNING: fit_platt_transform: Reaching maximal iterations\n");
385-
return MAX_ITER_REACHED;
392+
return MAX_ITER_REACHED;
386393
}
387-
return SUCCESS;
394+
return SUCCESS;
388395
}
389396
} // namespace pecos
390397
#endif

test/pecos/core/test_clib.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -78,12 +78,12 @@ def test_platt_scale():
7878

7979
orig = np.arange(-15, 15, 1, dtype=np.float32)
8080
tgt = np.array([1.0 / (1 + np.exp(A * t + B)) for t in orig], dtype=np.float32)
81-
At, Bt = clib.fit_platt_transform(orig, tgt)
81+
At, Bt = clib.fit_platt_transform(orig, tgt, clip_tgt_prob=False)
8282
assert B == approx(Bt, abs=1e-6), f"Platt_scale B error: {B} != {Bt}"
8383
assert A == approx(At, abs=1e-6), f"Platt_scale A error: {A} != {At}"
8484

8585
orig = np.arange(-15, 15, 1, dtype=np.float64)
8686
tgt = np.array([1.0 / (1 + np.exp(A * t + B)) for t in orig], dtype=np.float64)
87-
At, Bt = clib.fit_platt_transform(orig, tgt)
87+
At, Bt = clib.fit_platt_transform(orig, tgt, clip_tgt_prob=False)
8888
assert B == approx(Bt, abs=1e-6), f"Platt_scale B error: {B} != {Bt}"
8989
assert A == approx(At, abs=1e-6), f"Platt_scale A error: {A} != {At}"

0 commit comments

Comments
 (0)