33# '
44# ' @param ... the backends to be compared, one per model. Use `bridgesampling_args` for a list of arguments passed to [bridgesampling::bridge_sampler()].
55# ' @export
SBC_backend_bridgesampling <- function(..., model_var = "model", prior_probs = NULL, prior_prob1 = NULL, bridgesampling_args = list()) {
  # Builds a backend that compares an arbitrary number of models.
  #
  # Arguments:
  #   ...                  the backends to compare, one per model; stored in
  #                        the order given as `all_backends`.
  #   model_var            name of the variable holding the model index.
  #   prior_probs          optional numeric vector of prior model
  #                        probabilities, one per backend. Defaults to equal
  #                        probabilities for all backends.
  #   prior_prob1          optional shortcut for exactly two backends: prior
  #                        probability of the second backend. Cannot be
  #                        combined with `prior_probs`.
  #   bridgesampling_args  list stored on the backend as `bridgesampling_args`.
  #
  # Returns a list of class "SBC_backend_bridgesampling".

  require_package_version("bridgesampling", version = "1.0", purpose = "to use the bridgesampling SBC backend")

  all_backends <- list(...)

  # `prior_prob1` only makes sense for the two-model case; translate it to
  # the general `prior_probs` representation.
  if (!is.null(prior_prob1)) {
    stopifnot(length(all_backends) == 2)
    stopifnot(is.null(prior_probs))
    prior_probs <- c(1 - prior_prob1, prior_prob1)
  }

  if (is.null(prior_probs)) {
    prior_probs <- rep(1 / length(all_backends), times = length(all_backends))
  } else {
    stopifnot(is.numeric(prior_probs))
    # One prior probability per backend (previously compared against the
    # undefined `all_datasets`, which made any explicit `prior_probs` fail).
    stopifnot(length(prior_probs) == length(all_backends))
  }

  structure(
    list(
      all_backends = all_backends,
      model_var = model_var,
      prior_probs = prior_probs,
      bridgesampling_args = bridgesampling_args
    ),
    class = "SBC_backend_bridgesampling"
  )
}
1530
@@ -84,66 +99,112 @@ SBC_fit.SBC_backend_bridgesampling <- function(backend, generated, cores) {
8499 list (fit = fit_with_outputs $ res , bridge = bridge_with_outputs $ res )
85100 }
86101
87- fit_bridge_0 <- fit_single(backend $ backend_H0 , 0 )
88- fit_bridge_1 <- fit_single(backend $ backend_H1 , 1 )
102+ fit_bridges <- purrr :: map2(backend $ all_backends , 0 : (length(backend $ all_backends ) - 1 ), fit_single )
103+
104+ fits <- purrr :: map(fit_bridges , \(x ) x $ fit )
105+ bridges <- purrr :: map(fit_bridges , \(x ) x $ bridge )
89106
90107 structure(list (
91- fit0 = fit_bridge_0 $ fit ,
92- fit1 = fit_bridge_1 $ fit ,
93- bridge_H0 = fit_bridge_0 $ bridge ,
94- bridge_H1 = fit_bridge_1 $ bridge ,
108+ fits = fits ,
109+ bridges = bridges ,
95110 model_var = backend $ model_var ,
96- prior_prob1 = backend $ prior_prob1
111+ prior_probs = backend $ prior_probs
97112 ), class = " SBC_fit_bridgesampling" )
98113}
99114
# Turn the bridge sampling results stored in a fit into posterior model
# probabilities.
#
# `fit$bridges` holds one bridge sampling result per model and
# `fit$prior_probs` the matching prior model probabilities. With
# `log.p = TRUE` the probabilities are returned on the log scale.
SBC_fit_bridgesampling_to_probs <- function(fit, log.p = FALSE) {
  # The median is used when multiple bridgesampling iterations were run
  median_logml <- vapply(
    fit$bridges,
    function(bridge) median(bridge$logml, na.rm = TRUE),
    numeric(1)
  )

  if (anyNA(median_logml)) {
    # Show the offending bridges before failing to ease debugging
    print(fit$bridges)
    stop("Some logml values are NA.")
  }

  # Unnormalised log posterior: log marginal likelihood + log prior
  unnormalised <- median_logml + log(fit$prior_probs)

  # Softmax on the log scale
  normalised <- unnormalised - log_sum_exp(unnormalised)

  if (log.p) normalised else exp(normalised)
}
116135
#' @export
SBC_fit_to_draws_matrix.SBC_fit_bridgesampling <- function(fit) {
  # One chain-merged draws matrix per submodel fit
  draws_per_model <- purrr::map(
    fit$fits,
    \(single_fit) posterior::merge_chains(SBC_fit_to_draws_matrix(single_fit))
  )

  n_draws <- unique(purrr::map_int(draws_per_model, posterior::ndraws))

  if (length(n_draws) > 1) {
    warning("Unequal number of draws for each bridgesampling fit. Will subset to the smaller number.")
    # Truncate every submodel to the shortest available set of draws
    n_draws <- min(n_draws)
    draws_per_model <- purrr::map(
      draws_per_model,
      \(model_draws_matrix) posterior::subset_draws(model_draws_matrix, draw = 1:n_draws)
    )
  }

  model_probs <- SBC_fit_bridgesampling_to_probs(fit)
  n_models <- length(draws_per_model)

  # For each draw, pick the model index (0-based) the draw is taken from.
  model_draws <- if (n_models == 2) {
    # Keeping the old way for 2 models to not invalidate older results
    rbinom(n = n_draws, size = 1, prob = model_probs[2])
  } else {
    sample(0:(n_models - 1), size = n_draws, prob = model_probs, replace = TRUE)
  }

  combine_draws_matrix_for_bf(draws_per_model, model_draws, model_var = fit$model_var)
}
141166
#' @export
SBC_fit_specific_dquants.SBC_fit_bridgesampling <- function(fit) {
  # Index (1-based) of the model with the highest posterior probability
  best_index <- which.max(SBC_fit_bridgesampling_to_probs(fit))

  # The derived quantity `top_<model_var>` is defined as the indicator
  # `is_<model_var><k>` of that model, where k is the 0-based model index.
  indicator_name <- paste0("is_", fit$model_var, best_index - 1)
  top_name <- paste0("top_", fit$model_var)

  here <- rlang::current_env()
  indicator_quo <- rlang::parse_quo(indicator_name, here)

  top_attributes <- var_attributes_from_list(
    top_name,
    list(c(binary_var_attribute(), possibly_constant_var_attribute()))
  )

  call_args <- list(indicator_quo)
  names(call_args) <- top_name
  call_args[[".var_attributes"]] <- top_attributes

  do.call(derived_quantities, call_args)
}
182+
142183# ' @export
143184SBC_posterior_cdf.SBC_fit_bridgesampling <- function (fit , variables ) {
144185 if (fit $ model_var %in% names(variables )) {
145- prob1 <- SBC_fit_bridgesampling_to_prob1(fit )
146- return (binary_to_cdf(fit $ model_var , prob1 , variables [fit $ model_var ]))
186+ probs <- SBC_fit_bridgesampling_to_probs(fit )
187+ model_cdf <- discrete_to_cdf(fit $ model_var , probs , variables [fit $ model_var ])
188+ if (length(probs ) == 2 ) {
189+ return (model_cdf )
190+ } else {
191+ is_model_cdf_list <- list ()
192+ for (i in 1 : length(probs )) {
193+ is_var_name <- paste0(" is_" , fit $ model_var , i - 1 )
194+ if (is_var_name %in% names(variables )) {
195+ is_model_cdf_list [[i ]] <- binary_to_cdf(is_var_name , probs [i ], variables [is_var_name ])
196+ } else {
197+ is_model_cdf_list [[i ]] <- NULL
198+ }
199+ }
200+ is_model_cdf <- do.call(rbind , is_model_cdf_list )
201+
202+ max_index <- which.max(probs )
203+ simulated_value_top <- variables [fit $ model_var ] == max_index - 1
204+ top_prediction_cdf <- binary_to_cdf(paste0(" top_" , fit $ model_var ), probs [max_index ], simulated_value_top )
205+
206+ return (rbind(model_cdf , top_prediction_cdf , is_model_cdf ))
207+ }
147208 } else {
148209 return (NULL )
149210 }
@@ -162,17 +223,23 @@ SBC_fit_to_diagnostics.SBC_fit_bridgesampling <- function(fit, fit_output, fit_m
162223 }
163224 }
164225
165- diags0 <- SBC_fit_to_diagnostics(fit $ fit0 ,
166- get_prefixed_lines(hypothesis_output_prefix(0 ), fit_output ),
167- get_prefixed_lines(hypothesis_output_prefix(0 ), fit_messages ),
168- get_prefixed_lines(hypothesis_output_prefix(0 ), fit_warnings ))
169- diags1 <- SBC_fit_to_diagnostics(fit $ fit1 ,
170- get_prefixed_lines(hypothesis_output_prefix(1 ), fit_output ),
171- get_prefixed_lines(hypothesis_output_prefix(1 ), fit_messages ),
172- get_prefixed_lines(hypothesis_output_prefix(1 ), fit_warnings ))
226+ process_diag_single <- function (fit , model_index ) {
227+ diags <- SBC_fit_to_diagnostics(fit ,
228+ get_prefixed_lines(hypothesis_output_prefix(model_index ), fit_output ),
229+ get_prefixed_lines(hypothesis_output_prefix(model_index ), fit_messages ),
230+ get_prefixed_lines(hypothesis_output_prefix(model_index ), fit_warnings ))
231+ if (! is.null(diags )) {
232+ names(diags ) <- paste0(names(diags ), " _H" , model_index )
233+ }
234+ diags
235+ }
236+
237+ model_indices <- 0 : (length(fit $ fits ) - 1 )
238+
239+ diags_all <- purrr :: map2(fit $ fits , model_indices , process_diag_single )
173240
174- prob1 <- SBC_fit_bridgesampling_to_prob1 (fit )
175- log_prob1 <- SBC_fit_bridgesampling_to_prob1 (fit , log.p = TRUE )
241+ probs <- SBC_fit_bridgesampling_to_probs (fit )
242+ log_probs <- SBC_fit_bridgesampling_to_probs (fit , log.p = TRUE )
176243
177244 get_percentage_error <- function (bridge ) {
178245 errm <- bridgesampling :: error_measures(bridge )
@@ -217,39 +284,37 @@ SBC_fit_to_diagnostics.SBC_fit_bridgesampling <- function(fit, fit_output, fit_m
217284 diags
218285 }
219286
220- diags_bs <- cbind(
221- data.frame (prob_H1 = prob1 , log_prob_H1 = log_prob1 ),
222- get_bridge_diagnostics(fit $ bridge_H0 , 0 ),
223- get_bridge_diagnostics(fit $ bridge_H1 , 1 )
224- )
287+ bridge_diags_all <- purrr :: map2(fit $ bridges , model_indices , get_bridge_diagnostics )
225288
226- if (! is.null(diags0 )) {
227- names(diags0 ) <- paste0(names(diags0 ), " _H0" )
228- diags_bs <- cbind(diags_bs , diags0 )
229- }
230-
231- if (! is.null(diags1 )) {
232- names(diags1 ) <- paste0(names(diags1 ), " _H1" )
233- diags_bs <- cbind(diags_bs , diags1 )
289+ if (length(fit $ fits ) == 2 ) {
290+ probs_df <- data.frame (prob_H1 = probs [2 ], log_prob_H1 = log_probs [2 ])
291+ } else {
292+ probs_diag_vec <- c(probs , log_probs )
293+ names(probs_diag_vec ) <- c(paste0(" prob_H" , model_indices ), paste0(" log_prob_H" , model_indices ))
294+ probs_df <- t(as.data.frame(probs_diag_vec ))
295+ rownames(probs_df ) <- NULL
234296 }
297+ diags_bs <- do.call(cbind , c(list (probs_df ),
298+ bridge_diags_all , diags_all ))
235299
236300 class(diags_bs ) <- c(" SBC_bridgesampling_diagnostics" , class(diags_bs ))
237- attr(diags_bs , " submodel_classes" ) <- list ( H0 = class( diags0 ), H1 = class( diags1 ) )
301+ attr(diags_bs , " submodel_classes" ) <- purrr :: map( diags_all , class )
238302 return (diags_bs )
239303}
240304
241305# ' @export
242306diagnostic_types.SBC_bridgesampling_diagnostics <- function (diags ) {
243307 submodel_classes <- attr(diags , " submodel_classes" , exact = TRUE )
244308 if (is.null(submodel_classes ) ||
245- ! is.list(submodel_classes ) || ! identical(names( submodel_classes ), c( " H0 " , " H1 " )) ) {
309+ ! is.list(submodel_classes )) {
246310 warning(
247311r " (The 'submodel_classes' attribute of an SBC_bridgesampling_diagnostics data.frame
248312is not set or is in incorrect format.
249313Maybe you have modified the $backend_diagnostics element of SBC_results?
250314If not, please file an issue at https://github.com/hyunjimoon/SBC/issues/
251315)" )
252- submodel_diags <- list ()
316+ submodel_diags_all <- list ()
317+ n_submodels <- 2 # Typical guess, won't hurt reporting
253318 } else {
254319 get_submodel_diags <- function (i ) {
255320 bs_specific_diags <- list ()
@@ -269,7 +334,7 @@ If not, please file an issue at https://github.com/hyunjimoon/SBC/issues/
269334 \(name ) gsub(paste0(" _H" ,i ," $" ), " " , name ))
270335
271336
272- class(H_diags ) <- submodel_classes [[paste0( " H " , i ) ]]
337+ class(H_diags ) <- submodel_classes [[i + 1 ]]
273338 types_sub <- diagnostic_types(H_diags )
274339 types_mapped <- purrr :: map(types_sub ,
275340 \(diag ) submodel_diagnostic(paste0(" H" , i ), diag ))
@@ -280,20 +345,25 @@ If not, please file an issue at https://github.com/hyunjimoon/SBC/issues/
280345 )
281346 }
282347
283- submodel_diags <- c(
284- get_submodel_diags(0 ),
285- get_submodel_diags(1 )
286- )
287-
348+ n_submodels <- length(submodel_classes )
349+ submodel_diags_all <- do.call(c ,
350+ purrr :: map(0 : (n_submodels - 1 ), get_submodel_diags ))
288351 }
289352
290- c(
291- list (
353+ if ( n_submodels == 2 ) {
354+ prob_diags <- list (
292355 prob_H1 = numeric_diagnostic(" posterior probability of H1" , report = " quantiles" ),
293356 log_prob_H1 = skip_diagnostic()
294- ),
295- submodel_diags
296- )
357+ )
358+ } else {
359+ prob_diags <- list ()
360+ for (i in 0 : (n_submodels - 1 )) {
361+ prob_diags [[paste0(" prob_H" , i )]] <- numeric_diagnostic(paste0(" posterior probability of H" ,i ), report = " quantiles" )
362+ prob_diags [[paste0(" log_prob_H" , i )]] <- skip_diagnostic()
363+ }
364+ }
365+
366+ c(prob_diags , submodel_diags_all )
297367}
298368
299369# ' Custom rbind implementation maintaining information about submodels
@@ -326,7 +396,7 @@ rbind.SBC_bridgesampling_diagnostics <- function(...) {
326396 res
327397}
328398
329- # ' Custom select implementation maintainig information about submodels
399+ # ' Custom select implementation maintaining information about submodels
330400# ' @exportS3Method dplyr::select
331401select.SBC_bridgesampling_diagnostics <- function (diags , ... ) {
332402 selected <- NextMethod()
@@ -337,11 +407,6 @@ select.SBC_bridgesampling_diagnostics <- function(diags, ...) {
#' @export
SBC_backend_hash_for_cache.SBC_backend_bridgesampling <- function(backend) {
  # Replace every sub-backend by its own cache hash, so that the combined hash
  # is built from the sub-backends' hashes and not from their raw contents.
  hashable <- backend
  hashable[["all_backends"]] <- purrr::map(
    backend[["all_backends"]],
    SBC_backend_hash_for_cache
  )
  rlang::hash(hashable)
}
0 commit comments