Skip to content

Commit 83eefd5

Browse files
committed
Allowing multiple models for bridgesampling
1 parent c317542 commit 83eefd5

23 files changed

+710
-193
lines changed

NAMESPACE

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,8 @@ S3method(SBC_fit,SBC_backend_mock_rng)
3838
S3method(SBC_fit,SBC_backend_rjags)
3939
S3method(SBC_fit,SBC_backend_rstan_optimizing)
4040
S3method(SBC_fit,SBC_backend_rstan_sample)
41+
S3method(SBC_fit_specific_dquants,SBC_fit_bridgesampling)
42+
S3method(SBC_fit_specific_dquants,default)
4143
S3method(SBC_fit_to_BFBayesFactor,SBC_fit_lmBF)
4244
S3method(SBC_fit_to_bridge_sampler,SBC_backend_brms)
4345
S3method(SBC_fit_to_bridge_sampler,SBC_backend_cached)
@@ -131,6 +133,7 @@ export(SBC_example_backend)
131133
export(SBC_example_generator)
132134
export(SBC_example_results)
133135
export(SBC_fit)
136+
export(SBC_fit_specific_dquants)
134137
export(SBC_fit_to_BFBayesFactor)
135138
export(SBC_fit_to_bridge_sampler)
136139
export(SBC_fit_to_diagnostics)
@@ -176,9 +179,11 @@ export(default_diagnostic)
176179
export(default_diagnostics_types)
177180
export(derived_quantities)
178181
export(diagnostic_types)
182+
export(dquants_var_attributes)
179183
export(draws_rvars_to_standata)
180184
export(draws_rvars_to_standata_single)
181185
export(empirical_coverage)
186+
export(extract_attribute_arguments_stats)
182187
export(gaffke_ci)
183188
export(gaffke_p)
184189
export(gaffke_test)

R/backend-BayesFactor.R

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ SBC_fit_to_draws_matrix.SBC_fit_extractBF_comparison <- function(fit) {
8686
draws1 <- posterior::merge_chains(SBC_fit_to_draws_matrix(fit$fit1))
8787

8888
if(posterior::ndraws(draws0) != posterior::ndraws(draws1)) {
89-
warning("Unequal number of draws for each bridgesampling fit. Will subset to the smaller number.")
89+
warning("Unequal number of draws for each extractBF fit. Will subset to the smaller number.")
9090
if(posterior::ndraws(draws0) > posterior::ndraws(draws1)) {
9191
draws0 <- posterior::subset_draws(draws0, draw = 1:posterior::ndraws(draws1))
9292
} else {
@@ -100,7 +100,7 @@ SBC_fit_to_draws_matrix.SBC_fit_extractBF_comparison <- function(fit) {
100100

101101
model_draws <- rbinom(n = total_draws, size = 1, prob = prob1)
102102

103-
combined_draws <- SBC:::combine_draws_matrix_for_bf(draws0, draws1, model_draws, model_var = fit$model_var)
103+
combined_draws <- SBC:::combine_draws_matrix_for_bf(list(draws0, draws1), model_draws, model_var = fit$model_var)
104104

105105
return(combined_draws)
106106
}

R/backend-bridgesampling.R

Lines changed: 149 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,28 @@
33
#'
44
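#' @param ... backends for the individual models to be compared (one backend per model).
#' @param bridgesampling_args a list of additional arguments passed to [bridgesampling::bridge_sampler()].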
55
#' @export
SBC_backend_bridgesampling <- function(..., model_var = "model", prior_probs = NULL, prior_prob1 = NULL, bridgesampling_args = list()) {
  # Constructs a backend that wraps several model backends for comparison.
  #
  # ...                 - the backends of the individual models; stored in the
  #                       order given as `all_backends`.
  # model_var           - name of the variable holding the model index.
  # prior_probs         - numeric vector of prior model probabilities, one per
  #                       backend. Defaults to equal probabilities.
  # prior_prob1         - shorthand for the two-model case: prior probability of
  #                       the second backend. Cannot be combined with `prior_probs`.
  # bridgesampling_args - list stored unchanged on the backend object.
  #
  # Returns a list with class "SBC_backend_bridgesampling".

  require_package_version("bridgesampling", version = "1.0", purpose = " to use the bridgesampling SBC backend")

  all_backends <- list(...)
  if(!is.null(prior_prob1)) {
    # The single-probability shorthand only makes sense for exactly two models
    stopifnot(length(all_backends) == 2)
    stopifnot(is.null(prior_probs))
    prior_probs <- c(1 - prior_prob1, prior_prob1)
  }

  if(is.null(prior_probs)) {
    prior_probs <- rep(1 / length(all_backends), times = length(all_backends))
  } else {
    stopifnot(is.numeric(prior_probs))
    # One prior probability per backend (previously compared against the
    # undefined `all_datasets`, which made this branch always fail)
    stopifnot(length(prior_probs) == length(all_backends))
  }

  structure(list(all_backends = all_backends,
                 model_var = model_var,
                 prior_probs = prior_probs,
                 bridgesampling_args = bridgesampling_args),
            class = "SBC_backend_bridgesampling")
}
}
1530

@@ -84,66 +99,112 @@ SBC_fit.SBC_backend_bridgesampling <- function(backend, generated, cores) {
8499
list(fit = fit_with_outputs$res, bridge = bridge_with_outputs$res)
85100
}
86101

87-
fit_bridge_0 <- fit_single(backend$backend_H0, 0)
88-
fit_bridge_1 <- fit_single(backend$backend_H1, 1)
102+
fit_bridges <- purrr::map2(backend$all_backends, 0:(length(backend$all_backends) - 1), fit_single)
103+
104+
fits <- purrr::map(fit_bridges, \(x) x$fit)
105+
bridges <- purrr::map(fit_bridges, \(x) x$bridge)
89106

90107
structure(list(
91-
fit0 = fit_bridge_0$fit,
92-
fit1 = fit_bridge_1$fit,
93-
bridge_H0 = fit_bridge_0$bridge,
94-
bridge_H1 = fit_bridge_1$bridge,
108+
fits = fits,
109+
bridges = bridges,
95110
model_var = backend$model_var,
96-
prior_prob1 = backend$prior_prob1
111+
prior_probs = backend$prior_probs
97112
), class = "SBC_fit_bridgesampling")
98113
}
99114

SBC_fit_bridgesampling_to_probs <- function(fit, log.p = FALSE) {
  # Converts the bridge sampling results stored in `fit$bridges` into
  # posterior model probabilities (or their logs when `log.p` is TRUE),
  # combining them with `fit$prior_probs`.

  # A bridge may hold several logml values (repeated bridge sampling runs);
  # the median over the non-missing ones is used for each model.
  median_logml <- function(bridge) {
    median(bridge$logml, na.rm = TRUE)
  }
  log_marginals <- purrr::map_dbl(fit$bridges, median_logml)

  if(anyNA(log_marginals)) {
    print(fit$bridges)
    stop("Some logml values are NA.")
  }

  # Unnormalized log posterior: log marginal likelihood + log prior
  unnormalized <- log_marginals + log(fit$prior_probs)

  # Normalize on the log scale (softmax)
  normalized <- unnormalized - log_sum_exp(unnormalized)

  if(log.p) {
    return(normalized)
  }
  exp(normalized)
}
116135

117136
#' @export
118137
SBC_fit_to_draws_matrix.SBC_fit_bridgesampling <- function(fit) {
  # Builds a single draws matrix for the model comparison: the draws of every
  # sub-fit are merged across chains, a model index is sampled for each draw
  # according to the posterior model probabilities, and everything is handed
  # to combine_draws_matrix_for_bf().

  per_model_matrices <- purrr::map(fit$fits, SBC_fit_to_draws_matrix)
  per_model_draws <- purrr::map(per_model_matrices, posterior::merge_chains)

  draws_counts <- purrr::map_int(per_model_draws, posterior::ndraws)
  n_shared <- unique(draws_counts)

  if(length(n_shared) > 1) {
    warning("Unequal number of draws for each bridgesampling fit. Will subset to the smaller number.")
    n_shared <- min(n_shared)
    keep_first_draws <- function(draws) {
      posterior::subset_draws(draws, draw = 1:n_shared)
    }
    per_model_draws <- purrr::map(per_model_draws, keep_first_draws)
  }

  model_probs <- SBC_fit_bridgesampling_to_probs(fit)
  n_models <- length(per_model_draws)

  if(n_models != 2) {
    sampled_model <- sample(0:(n_models - 1), size = n_shared, prob = model_probs, replace = TRUE)
    # TODO figure out top_model to have in stats (presumably via the DQ mechanism...)
  } else {
    # Two models keep the original Bernoulli sampling so that older results
    # are not invalidated
    sampled_model <- rbinom(n = n_shared, size = 1, prob = model_probs[2])
  }

  combine_draws_matrix_for_bf(per_model_draws, sampled_model, model_var = fit$model_var)
}
141166

167+
#' @export
168+
SBC_fit_specific_dquants.SBC_fit_bridgesampling <- function(fit) {
  # Creates a derived quantity named `top_<model_var>` whose expression is
  # the indicator variable `is_<model_var><k>`, where k is the 0-based index
  # of the model with the highest posterior probability for this fit.

  best_model <- which.max(SBC_fit_bridgesampling_to_probs(fit)) - 1
  top_name <- paste0("top_", fit$model_var)

  # The quosure is anchored in this function's environment
  indicator_quo <- rlang::parse_quo(
    paste0("is_", fit$model_var, best_model),
    rlang::current_env()
  )

  # Mark the new quantity as binary and possibly constant
  top_attributes <- var_attributes_from_list(
    top_name,
    list(c(binary_var_attribute(), possibly_constant_var_attribute()))
  )

  dq_args <- stats::setNames(list(indicator_quo), top_name)
  dq_args$.var_attributes <- top_attributes

  do.call(derived_quantities, dq_args)
}
182+
142183
#' @export
143184
SBC_posterior_cdf.SBC_fit_bridgesampling <- function(fit, variables) {
144185
if(fit$model_var %in% names(variables)) {
145-
prob1 <- SBC_fit_bridgesampling_to_prob1(fit)
146-
return(binary_to_cdf(fit$model_var, prob1, variables[fit$model_var]))
186+
probs <- SBC_fit_bridgesampling_to_probs(fit)
187+
model_cdf <- discrete_to_cdf(fit$model_var, probs, variables[fit$model_var])
188+
if(length(probs) == 2) {
189+
return(model_cdf)
190+
} else {
191+
is_model_cdf_list <- list()
192+
for(i in 1:length(probs)) {
193+
is_var_name <- paste0("is_", fit$model_var, i - 1)
194+
if(is_var_name %in% names(variables)) {
195+
is_model_cdf_list[[i]] <- binary_to_cdf(is_var_name, probs[i], variables[is_var_name])
196+
} else {
197+
is_model_cdf_list[[i]] <- NULL
198+
}
199+
}
200+
is_model_cdf <- do.call(rbind, is_model_cdf_list)
201+
202+
max_index <- which.max(probs)
203+
simulated_value_top <- variables[fit$model_var] == max_index - 1
204+
top_prediction_cdf <- binary_to_cdf(paste0("top_", fit$model_var), probs[max_index], simulated_value_top)
205+
206+
return(rbind(model_cdf, top_prediction_cdf, is_model_cdf))
207+
}
147208
} else {
148209
return(NULL)
149210
}
@@ -162,17 +223,23 @@ SBC_fit_to_diagnostics.SBC_fit_bridgesampling <- function(fit, fit_output, fit_m
162223
}
163224
}
164225

165-
diags0 <- SBC_fit_to_diagnostics(fit$fit0,
166-
get_prefixed_lines(hypothesis_output_prefix(0), fit_output),
167-
get_prefixed_lines(hypothesis_output_prefix(0), fit_messages),
168-
get_prefixed_lines(hypothesis_output_prefix(0), fit_warnings))
169-
diags1 <- SBC_fit_to_diagnostics(fit$fit1,
170-
get_prefixed_lines(hypothesis_output_prefix(1), fit_output),
171-
get_prefixed_lines(hypothesis_output_prefix(1), fit_messages),
172-
get_prefixed_lines(hypothesis_output_prefix(1), fit_warnings))
226+
process_diag_single <- function(fit, model_index) {
227+
diags <- SBC_fit_to_diagnostics(fit,
228+
get_prefixed_lines(hypothesis_output_prefix(model_index), fit_output),
229+
get_prefixed_lines(hypothesis_output_prefix(model_index), fit_messages),
230+
get_prefixed_lines(hypothesis_output_prefix(model_index), fit_warnings))
231+
if(!is.null(diags)) {
232+
names(diags) <- paste0(names(diags), "_H", model_index)
233+
}
234+
diags
235+
}
236+
237+
model_indices <- 0:(length(fit$fits) - 1)
238+
239+
diags_all <- purrr::map2(fit$fits, model_indices, process_diag_single)
173240

174-
prob1 <- SBC_fit_bridgesampling_to_prob1(fit)
175-
log_prob1 <- SBC_fit_bridgesampling_to_prob1(fit, log.p = TRUE)
241+
probs <- SBC_fit_bridgesampling_to_probs(fit)
242+
log_probs <- SBC_fit_bridgesampling_to_probs(fit, log.p = TRUE)
176243

177244
get_percentage_error <- function(bridge) {
178245
errm <- bridgesampling::error_measures(bridge)
@@ -217,39 +284,37 @@ SBC_fit_to_diagnostics.SBC_fit_bridgesampling <- function(fit, fit_output, fit_m
217284
diags
218285
}
219286

220-
diags_bs <- cbind(
221-
data.frame(prob_H1 = prob1, log_prob_H1 = log_prob1),
222-
get_bridge_diagnostics(fit$bridge_H0, 0),
223-
get_bridge_diagnostics(fit$bridge_H1, 1)
224-
)
287+
bridge_diags_all <- purrr::map2(fit$bridges, model_indices, get_bridge_diagnostics)
225288

226-
if(!is.null(diags0)) {
227-
names(diags0) <- paste0(names(diags0), "_H0")
228-
diags_bs <- cbind(diags_bs, diags0)
229-
}
230-
231-
if(!is.null(diags1)) {
232-
names(diags1) <- paste0(names(diags1), "_H1")
233-
diags_bs <- cbind(diags_bs, diags1)
289+
if(length(fit$fits) == 2) {
290+
probs_df <- data.frame(prob_H1 = probs[2], log_prob_H1 = log_probs[2])
291+
} else {
292+
probs_diag_vec <- c(probs, log_probs)
293+
names(probs_diag_vec) <- c(paste0("prob_H", model_indices), paste0("log_prob_H", model_indices))
294+
probs_df <- t(as.data.frame(probs_diag_vec))
295+
rownames(probs_df) <- NULL
234296
}
297+
diags_bs <- do.call(cbind, c(list(probs_df),
298+
bridge_diags_all, diags_all))
235299

236300
class(diags_bs) <- c("SBC_bridgesampling_diagnostics", class(diags_bs))
237-
attr(diags_bs, "submodel_classes") <- list(H0 = class(diags0), H1 = class(diags1))
301+
attr(diags_bs, "submodel_classes") <- purrr::map(diags_all, class)
238302
return(diags_bs)
239303
}
240304

241305
#' @export
242306
diagnostic_types.SBC_bridgesampling_diagnostics <- function(diags) {
243307
submodel_classes <- attr(diags, "submodel_classes", exact = TRUE)
244308
if(is.null(submodel_classes) ||
245-
!is.list(submodel_classes) || !identical(names(submodel_classes), c("H0", "H1"))) {
309+
!is.list(submodel_classes)) {
246310
warning(
247311
r"(The 'submodel_classes' attribute of an SBC_bridgesampling_diagnostics data.frame
248312
is not set or is in incorrect format.
249313
Maybe you have modified the $backend_diagnostics element of SBC_results?
250314
If not, please file an issue at https://github.com/hyunjimoon/SBC/issues/
251315
)")
252-
submodel_diags <- list()
316+
submodel_diags_all <- list()
317+
n_submodels <- 2 #Typical guess, won't hurt reporting
253318
} else {
254319
get_submodel_diags <- function(i) {
255320
bs_specific_diags <- list()
@@ -269,7 +334,7 @@ If not, please file an issue at https://github.com/hyunjimoon/SBC/issues/
269334
\(name) gsub(paste0("_H",i,"$"), "", name))
270335

271336

272-
class(H_diags) <- submodel_classes[[paste0("H",i)]]
337+
class(H_diags) <- submodel_classes[[i + 1]]
273338
types_sub <- diagnostic_types(H_diags)
274339
types_mapped <- purrr::map(types_sub,
275340
\(diag) submodel_diagnostic(paste0("H", i), diag))
@@ -280,20 +345,25 @@ If not, please file an issue at https://github.com/hyunjimoon/SBC/issues/
280345
)
281346
}
282347

283-
submodel_diags <- c(
284-
get_submodel_diags(0),
285-
get_submodel_diags(1)
286-
)
287-
348+
n_submodels <- length(submodel_classes)
349+
submodel_diags_all <- do.call(c,
350+
purrr::map(0:(n_submodels - 1), get_submodel_diags))
288351
}
289352

290-
c(
291-
list(
353+
if(n_submodels == 2) {
354+
prob_diags <- list(
292355
prob_H1 = numeric_diagnostic("posterior probability of H1", report = "quantiles"),
293356
log_prob_H1 = skip_diagnostic()
294-
),
295-
submodel_diags
296-
)
357+
)
358+
} else {
359+
prob_diags <- list()
360+
for(i in 0:(n_submodels - 1)) {
361+
prob_diags[[paste0("prob_H", i)]] <- numeric_diagnostic(paste0("posterior probability of H",i), report = "quantiles")
362+
prob_diags[[paste0("log_prob_H", i)]] <- skip_diagnostic()
363+
}
364+
}
365+
366+
c(prob_diags, submodel_diags_all)
297367
}
298368

299369
#' Custom rbind implementation maintaining information about submodels
@@ -326,7 +396,7 @@ rbind.SBC_bridgesampling_diagnostics <- function(...) {
326396
res
327397
}
328398

329-
#' Custom select implementation maintainig information about submodels
399+
#' Custom select implementation maintaining information about submodels
330400
#' @exportS3Method dplyr::select
331401
select.SBC_bridgesampling_diagnostics <- function(diags, ...) {
332402
selected <- NextMethod()
@@ -337,11 +407,6 @@ select.SBC_bridgesampling_diagnostics <- function(diags, ...) {
337407
#' @export
338408
SBC_backend_hash_for_cache.SBC_backend_bridgesampling <- function(backend) {
  # Computes the cache hash of the backend: every wrapped sub-backend is first
  # replaced by its own SBC_backend_hash_for_cache() value, then the whole
  # object is hashed.
  sub_backend_hashes <- purrr::map(backend$all_backends, SBC_backend_hash_for_cache)

  hashable <- backend
  hashable$all_backends <- sub_backend_hashes

  rlang::hash(hashable)
}

0 commit comments

Comments
 (0)