The bonsaiforest2 package provides Bayesian shrinkage estimation of
subgroup treatment effects in randomized clinical trials. It supports
both one-way (random effects) and fixed effects models for estimating
treatment-by-subgroup interactions, with built-in support for
continuous, binary, time-to-event (Cox), and count outcomes. The
package offers state-of-the-art shrinkage priors, including the
Regularized Horseshoe and R2D2, and combines them with standardization
(G-computation) to produce interpretable marginal treatment effects.
Built on brms and Stan, bonsaiforest2 is a practical tool for obtaining
more stable and reliable subgroup effect estimates in exploratory and
confirmatory analyses.
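
To make the standardization (G-computation) step concrete, here is a minimal sketch in base R. It is not the package's implementation: it uses an ordinary `lm()` fit instead of a Bayesian model, and the names `toy` and `g_comp` are hypothetical. The idea is to predict each patient's outcome under both treatment arms and average the difference within each subgroup.

```r
# Minimal G-computation sketch (illustrative only; not bonsaiforest2 code)
set.seed(1)
n <- 400
toy <- data.frame(
  trt = rep(0:1, length.out = n),
  x1  = factor(sample(c("A", "B"), n, replace = TRUE))
)
# Assumed true effects for this toy data: 0.2 in subgroup A, 0.6 in subgroup B
toy$y <- 1 + toy$trt * ifelse(toy$x1 == "B", 0.6, 0.2) + rnorm(n, sd = 1)

# Outcome model with a treatment-by-subgroup interaction
fit <- lm(y ~ trt * x1, data = toy)

# Hypothetical helper: predict under trt = 1 and trt = 0 for every patient,
# then average the individual differences within each subgroup level
g_comp <- function(model, data, subgroup_var, trt_var = "trt") {
  d1 <- data
  d1[[trt_var]] <- 1
  d0 <- data
  d0[[trt_var]] <- 0
  diff <- predict(model, newdata = d1) - predict(model, newdata = d0)
  tapply(diff, data[[subgroup_var]], mean)
}

marginal_effects <- g_comp(fit, toy, "x1")
print(round(marginal_effects, 3))
```

In this linear toy model the result coincides with sums of regression coefficients. The averaging step matters once the model includes further covariates or a nonlinear link, where a marginal effect can no longer be read off a single coefficient.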
You can install the development version of bonsaiforest2 from its
GitHub repository:
# install.packages("remotes")
remotes::install_github("openpharma/bonsaiforest2")

This example demonstrates Bayesian shrinkage estimation of treatment effects across subgroups using a fixed effects model with a Regularized Horseshoe prior. We simulate data with five subgrouping variables, of which one (x1) carries true treatment effect heterogeneity.
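
As background for the Regularized Horseshoe prior used in this example, the following base R sketch draws coefficients from the prior with fixed hyperparameters. It is an illustration under stated assumptions, not what brms or bonsaiforest2 run internally: the global scale `tau` and slab scale `c2` are fixed at arbitrary values here, whereas in a fitted model they usually receive their own priors.

```r
# Prior draws from a horseshoe and a regularized horseshoe (illustrative only)
set.seed(42)
n_draws <- 10000
tau <- 0.1  # global shrinkage scale (assumed fixed for this sketch)
c2  <- 2^2  # squared slab scale (assumed fixed for this sketch)

# Local scales: half-Cauchy(0, 1)
lambda <- abs(rcauchy(n_draws))

# Regularized local scales: close to lambda^2 when tau * lambda is small,
# capped near c2 / tau^2 when tau * lambda is large
lambda_tilde2 <- c2 * lambda^2 / (c2 + tau^2 * lambda^2)

beta_hs  <- rnorm(n_draws, 0, tau * lambda)               # plain horseshoe
beta_rhs <- rnorm(n_draws, 0, tau * sqrt(lambda_tilde2))  # regularized horseshoe

# Most prior mass sits near zero, with a tail for a few large effects
print(quantile(abs(beta_rhs), c(0.5, 0.9, 0.99)))

# The regularized prior standard deviation never exceeds the slab scale
print(max(tau * sqrt(lambda_tilde2)) <= sqrt(c2) + 1e-8)
```

The spike near zero pulls small interaction estimates toward no heterogeneity, while the slab bounds how large a coefficient the prior still considers plausible.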
library(bonsaiforest2)
# 1. Simulate continuous outcome trial data with treatment heterogeneity
set.seed(123)
n <- 500
dat <- data.frame(
  y = rnorm(n, mean = 50, sd = 15),  # placeholder outcome, overwritten below
  trt = rep(0:1, length.out = n),
  x1 = factor(sample(c("A", "B"), n, replace = TRUE)),
  x2 = factor(sample(c("A", "B", "C"), n, replace = TRUE)),
  x3 = factor(sample(c("A", "B"), n, replace = TRUE)),
  x4 = factor(sample(c("A", "B"), n, replace = TRUE)),
  x5 = factor(sample(c("A", "B"), n, replace = TRUE))
)
# Heterogeneous treatment effect: depends on x1 only
trt_effect <- dat$trt * (
  0.35 +                                     # Base treatment effect
  0.35 * (as.numeric(dat$x1 == "B") - 0.35)  # Heterogeneity by x1
)
# Replace the placeholder outcome with the treatment effect plus noise
dat$y <- trt_effect + rnorm(n, 0, 1.2)
# 2. Fit fixed effects model with heterogeneous shrinkage
fit_fixed <- run_brms_analysis(
  data = dat,
  response_type = "continuous",
  response_formula = y ~ trt,
  unshrunk_terms_formula = ~ x1 + x2 + x3 + x4 + x5,
  shrunk_prognostic_formula = NULL,
  shrunk_predictive_formula = ~ 0 + trt:x1 + trt:x2 + trt:x3 + trt:x4 + trt:x5,
  intercept_prior = "normal(50, 15)",
  unshrunk_prior = "normal(0, 5)",
  shrunk_prognostic_prior = NULL,
  shrunk_predictive_prior = "horseshoe(1)",
  chains = 2, iter = 1000, warmup = 500,
  backend = "cmdstanr", refresh = 0
)
#> Running MCMC with 2 sequential chains...
#>
#> Chain 1 finished in 2.1 seconds.
#> Chain 2 finished in 2.0 seconds.
#>
#> Both chains finished successfully.
#> Mean chain execution time: 2.0 seconds.
#> Total execution time: 4.3 seconds.
# 3. Extract and visualize marginal treatment effects by subgroup
subgroup_effects <- summary_subgroup_effects(fit_fixed)
print(subgroup_effects)
#> $estimates
#> # A tibble: 11 × 4
#>    Subgroup Median CI_Lower CI_Upper
#>    <chr>     <dbl>    <dbl>    <dbl>
#>  1 x1: A     0.212  -0.113     0.501
#>  2 x1: B     0.554   0.249     0.908
#>  3 x2: A     0.412   0.161     0.739
#>  4 x2: B     0.325  -0.0227    0.581
#>  5 x2: C     0.398   0.140     0.701
#>  6 x3: A     0.376   0.120     0.616
#>  7 x3: B     0.369   0.127     0.609
#>  8 x4: A     0.382   0.163     0.632
#>  9 x4: B     0.360   0.0824    0.619
#> 10 x5: A     0.401   0.168     0.703
#> 11 x5: B     0.344   0.0918    0.585
#>
#> $response_type
#> [1] "continuous"
#>
#> $ci_level
#> [1] 0.95
#>
#> $trt_var
#> [1] "trt"
#>
#> attr(,"class")
#> [1] "subgroup_summary"
plot(subgroup_effects)
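
As a sanity check on the estimates above, the true subgroup effects implied by the simulation formula can be worked out directly. The sketch below only repeats the arithmetic from the data-generating step. The names `base_effect`, `true_x1`, and `true_overall` are introduced here for illustration, and the overall value assumes an even A/B split in x1, which holds in expectation but not exactly in a given sample.

```r
# True treatment effects implied by: 0.35 + 0.35 * (I(x1 == "B") - 0.35)
base_effect <- 0.35
true_x1 <- c(
  A = base_effect + 0.35 * (0 - 0.35),  # 0.35 - 0.1225 = 0.2275
  B = base_effect + 0.35 * (1 - 0.35)   # 0.35 + 0.2275 = 0.5775
)
print(true_x1)

# Expected effect in subgroups defined by x2 to x5, which do not modify the
# treatment effect: the average over x1 (assuming a 50/50 split)
true_overall <- mean(true_x1)  # 0.4025
print(true_overall)
```

The reported posterior medians for x1 (0.212 for A and 0.554 for B) are close to these true values, and both 95% intervals cover them. The estimates for the x2 to x5 subgroups cluster around the overall effect, as expected when those variables carry no heterogeneity.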
