Skip to content

Commit 348a494

Browse files
committed
Support binary calibration tests, improved calibration plots
1 parent 22be061 commit 348a494

File tree

3 files changed

+208
-10
lines changed

3 files changed

+208
-10
lines changed

.Rprofile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
source("renv/activate.R")
2-
source("~/.Rprofile")
2+
if(file.exists("~/.Rprofile")) source("~/.Rprofile")
33
# Allows to change how all vignettes are run at once (especially to test rstan)
44
options("SBC.vignettes_cmdstanr" = TRUE)
55

R/binary-calibration-tests.R

Lines changed: 182 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,182 @@
1+
brier_score <- function(x, y) {
2+
sum((x-y)^2)
3+
}
4+
5+
brier_resampling_p <- function(x, y, B = 10000) {
6+
actual_brier <- brier_score(x, y)
7+
brier_null <- replicate(B, {
8+
yrep <- rbinom(length(x), size = 1, prob = x)
9+
brier_score(x, yrep)
10+
})
11+
max(mean(actual_brier <= brier_null), 0.5/B)
12+
}
13+
14+
brier_resampling_test <- function(x, y, alpha = 0.05, B = 10000) {
15+
dname <- paste0("x = ", deparse1(substitute(x)), ", y = ", deparse1(substitute(y)))
16+
17+
actual_brier <- brier_score(x, y)
18+
brier_null <- replicate(B, {
19+
yrep <- rbinom(length(x), size = 1, prob = x)
20+
brier_score(x, yrep)
21+
})
22+
23+
p <- max(mean(actual_brier <= brier_null), 0.5/B)
24+
25+
param <- quantile(brier_null, probs = 1 - alpha)
26+
names(param) <- paste0(scales::percent(1 - alpha), " rejection limit")
27+
28+
structure(list(
29+
method = paste0("Bootstrapped binary Brier score test (using ", B, " samples)"),
30+
data.name = dname,
31+
p.value = p,
32+
estimate = c("Brier score" = actual_brier),
33+
parameter = param
34+
),
35+
class = "htest")
36+
}
37+
38+
binary_miscalibration <- function(x,y) {
39+
require_package_version("monotone", "0.1.2", "miscalibration computations")
40+
ord <- order(x, -y)
41+
x <- x[ord]
42+
y <- y[ord]
43+
#CEP_pav <- stats::isoreg(y)$yf
44+
CEP_pav <- monotone::monotone(y)
45+
#Using brier score
46+
Sc <- mean((CEP_pav - y)^2)
47+
mean((x - y) ^2) - Sc
48+
}
49+
50+
# Faster reimplementation from https://www.pnas.org/doi/full/10.1073/pnas.2016191118#sec-4
51+
# and the reliabilitydiag package
52+
miscalibration_resampling_nulldist <- function(x,y, B = 1000) {
53+
replicate(B, {
54+
yrep <- rbinom(length(x), size = 1, prob = x)
55+
binary_miscalibration(x, yrep)
56+
})
57+
}
58+
59+
miscalibration_resampling_p <- function(x,y, B = 10000) {
60+
actual_miscalibration <- binary_miscalibration(x,y)
61+
misc_null <- miscalibration_resampling_nulldist(x, y, B)
62+
max(mean(actual_miscalibration <= misc_null), 0.5/B)
63+
}
64+
65+
#' @export
66+
miscalibration_resampling_test <- function(x, y, alpha = 0.05, B = 10000) {
67+
dname <- paste0("x = ", deparse1(substitute(x)), ", y = ", deparse1(substitute(y)))
68+
69+
actual_miscalibration <- binary_miscalibration(x,y)
70+
misc_null <- miscalibration_resampling_nulldist(x, y, B)
71+
p <- max(mean(actual_miscalibration <= misc_null), 0.5/B)
72+
73+
param <- quantile(misc_null, probs = 1 - alpha)
74+
names(param) <- paste0(scales::percent(1 - alpha), " rejection limit")
75+
76+
structure(list(
77+
method = paste0("Bootstrapped binary miscalibration test (using ", B, " samples)"),
78+
data.name = dname,
79+
p.value = p,
80+
estimate = c("miscalibration" = actual_miscalibration),
81+
parameter = param
82+
),
83+
class = "htest")
84+
}
85+
86+
gaffke_m <- function(probs, B = 10000) {
87+
require_package_version("MCMCpack", "1.0.0", "the Gaffke test")
88+
u_diff <- MCMCpack::rdirichlet(B, alpha = rep(1, length(probs) + 1))
89+
90+
probs_sort <- sort(probs)
91+
z_upr <- c(probs_sort, 1)
92+
m_matrix_upr <- sweep(u_diff, MARGIN = 2, STATS = z_upr, FUN = "*")
93+
m_upr <- rowSums(m_matrix_upr)
94+
95+
#stopifnot(identical(sort(1 - probs), rev(1 - probs_sort)))
96+
z_lwr <- c(rev(1 - probs_sort), 1)
97+
m_matrix_lwr <- sweep(u_diff, MARGIN = 2, STATS = z_lwr, FUN = "*")
98+
m_lwr <- rowSums(m_matrix_lwr)
99+
100+
list(lwr = m_lwr, upr = m_upr)
101+
}
102+
103+
gaffke_ci_from_m <- function(m, alpha = 0.05) {
104+
m_lwr <- m$lwr
105+
m_upr <- m$upr
106+
107+
as.numeric(c(
108+
1 - quantile(m_lwr, probs = 1 - alpha / 2),
109+
quantile(m_upr, probs = 1 - alpha / 2)
110+
))
111+
}
112+
113+
gaffke_ci <- function(probs, B = 10000, alpha = 0.05) {
114+
m <- gaffke_m(probs, B, alpha)
115+
gaffke_ci_from_m(m, alpha)
116+
}
117+
118+
gaffke_p_from_m <- function(m, mu, B, alternative = c("two.sided", "less", "greater")) {
119+
alternative <- match.arg(alternative)
120+
121+
m_lwr <- m$lwr
122+
m_upr <- m$upr
123+
124+
prob_low <- mean(1-m_lwr <= mu)
125+
if(prob_low == 0) {
126+
prob_low <- 0.5/B
127+
}
128+
prob_high <- mean(m_upr >= mu)
129+
if(prob_high == 0) {
130+
prob_high <- 0.5/B
131+
}
132+
if(alternative == "two.sided") {
133+
return(min(prob_low, prob_high, 0.5) * 2)
134+
} else if(alternative == "less") {
135+
return(prob_high)
136+
} else if(alternative == "greater") {
137+
return(prob_low)
138+
} else {
139+
stop("Invalid alternative")
140+
}
141+
}
142+
143+
gaffke_p <- function(probs, mu = 0.5, alpha = 0.05, B = 10000, alternative = c("two.sided", "less", "greater")) {
144+
alternative <- match.arg(alternative)
145+
146+
m <- gaffke_m(probs, B, alpha)
147+
gaffke_p_from_m(m, mu, B, alternative)
148+
}
149+
150+
#' Non-parametric test for the mean of a bounded variable.
151+
#' @export
152+
gaffke_test <- function(x, mu = 0.5, alpha = 0.05, lb = 0, ub = 1, B = 10000, alternative = c("two.sided", "less", "greater")) {
153+
dname <- deparse1(substitute(x))
154+
alternative <- match.arg(alternative)
155+
156+
stopifnot(length(lb) == 1)
157+
stopifnot(length(ub) == 1)
158+
stopifnot(all(x >= lb))
159+
stopifnot(all(x <= ub))
160+
stopifnot(length(B) == 1 && B > 1)
161+
stopifnot(0 < alpha && alpha < 1)
162+
stopifnot(mu >= lb && mu <= ub)
163+
164+
x_scaled <- (x - lb) / (ub - lb)
165+
mu_scaled <- (mu - lb) / (ub - lb)
166+
m <- gaffke_m(x_scaled, B = B)
167+
p <- gaffke_p_from_m(m, mu_scaled, alternative = alternative)
168+
ci <- gaffke_ci_from_m(m, alpha = alpha)
169+
attr(ci, "conf.level") <- 1 - alpha
170+
171+
structure(list(
172+
method = paste0("Gaffke's test for the mean of a bounded variable (using ", B, " samples)"),
173+
data.name = dname,
174+
p.value = p,
175+
alternative = alternative,
176+
null.value = c("mean" = mu),
177+
conf.int = ci,
178+
estimate = c("mean" = mean(x)),
179+
parameter = c("lower bound" = lb, "upper bound" = ub)
180+
),
181+
class = "htest")
182+
}

R/binary-calibration.R

Lines changed: 25 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ binary_probabilities_from_stats <- function(stats) {
1919
}
2020

2121
#' @export
22-
binary_calibration_from_stats <- function(stats, type = "isotonic", ...) {
22+
binary_calibration_from_stats <- function(stats, type = c("reliabilitydiag", "calibrationband"), ...) {
2323
stats <- binary_probabilities_from_stats(stats)
2424

2525
stats_grouped <- dplyr::group_by(stats, variable)
@@ -29,8 +29,10 @@ binary_calibration_from_stats <- function(stats, type = "isotonic", ...) {
2929
}
3030

3131
#' @export
32-
binary_calibration_base <- function(prob, outcome, type = "isotonic", ...) {
33-
stopifnot(is.numeric(prob) && is.numeric(outcome))
32+
binary_calibration_base <- function(prob, outcome, uncertainty_prob = 0.95, type = c("reliabilitydiag", "calibrationband"), ...) {
33+
stopifnot(is.numeric(prob))
34+
stopifnot((is.numeric(outcome) || is.logical(outcome) || is.integer(outcome)))
35+
outcome <- as.numeric(outcome)
3436
stopifnot(all(outcome %in% c(0,1)))
3537
stopifnot(all(prob >=0 & prob <= 1))
3638
stopifnot(length(prob) == length(outcome))
@@ -41,8 +43,22 @@ binary_calibration_base <- function(prob, outcome, type = "isotonic", ...) {
4143
outcome <- outcome[!na_indices]
4244

4345
type <- match.arg(type)
44-
if(type == "isotonic") {
45-
require_package_version("calibrationband", "0.2", "to compute binary calibration with the type 'isotonic'.")
46+
if(type == "reliabilitydiag") {
47+
require_package_version("reliabilitydiag", "0.2.1", "to compute binary calibration with the type 'reliabilitydiag'.")
48+
rel_diag <- reliabilitydiag::reliabilitydiag(
49+
x = prob,
50+
y = outcome,
51+
region.level = uncertainty_prob,
52+
...
53+
)
54+
res <- data.frame(prob = rel_diag$x$regions$x, low = rel_diag$x$regions$lower, high = rel_diag$x$regions$upper)
55+
res$estimate <- approx(x = c(rel_diag$x$bins$x_min, rel_diag$x$bins$x_max),
56+
y = rep(rel_diag$x$bins$CEP_pav, times = 2),
57+
xout = res$prob)$y
58+
59+
return(res)
60+
} else if(type == "calibrationband") {
61+
require_package_version("calibrationband", "0.2", "to compute binary calibration with the type 'calibrationband'.")
4662
# Need to remove extreme indices because they cause crashes in the package
4763
extreme_indices <- prob < 1e-10 | prob > 1 - 1e-10
4864
extreme_indices_mismatch <- extreme_indices & round(prob) != outcome
@@ -56,7 +72,7 @@ binary_calibration_base <- function(prob, outcome, type = "isotonic", ...) {
5672
# Avoiding https://github.com/marius-cp/calibrationband/issues/1
5773
prob <- round(prob, digits = 7)
5874

59-
bands <- calibrationband::calibration_bands(prob, outcome, ...)
75+
bands <- calibrationband::calibration_bands(prob, outcome, alpha = 1 - uncertainty_prob, ...)
6076

6177
res <- dplyr::transmute(bands$bands, prob = x, low = lwr, high = upr)
6278

@@ -77,7 +93,7 @@ binary_calibration_base <- function(prob, outcome, type = "isotonic", ...) {
7793

7894

7995
#' @export
80-
plot_binary_calibration_diff <- function(stats, type = "isotonic", ...) {
96+
plot_binary_calibration_diff <- function(stats, type = c("reliabilitydiag", "calibrationband"), ...) {
8197
calib_df <- binary_calibration_from_stats(stats, type = type, ...)
8298

8399
ggplot(calib_df, aes(x = prob, ymin = low - prob, ymax = high - prob, y = estimate - prob)) +
@@ -87,11 +103,11 @@ plot_binary_calibration_diff <- function(stats, type = "isotonic", ...) {
87103
}
88104

89105
#' @export
90-
plot_binary_calibration <- function(stats, type = "isotonic", ...) {
106+
plot_binary_calibration <- function(stats, type = c("reliabilitydiag", "calibrationband"), ...) {
91107
calib_df <- binary_calibration_from_stats(stats, type = type, ...)
92108

93109
ggplot(calib_df, aes(x = prob, ymin = low, ymax = high, y = estimate)) +
94110
geom_segment(x = 0, y = 0, xend = 1, yend = 1, color = "skyblue1", size = 2) +
95111
geom_ribbon(fill = "black", alpha = 0.33) +
96-
geom_line() + facet_wrap(~variable)
112+
geom_line() + facet_wrap(~variable) + coord_fixed()
97113
}

0 commit comments

Comments
 (0)