Skip to content

Commit a170ea8

Browse files
authored
Improve efficiency of clean_parameters() for more complex *brms* models (#1082)
* Improve efficiency of `clean_parameters()` for more complex *brms* models * fix
1 parent 0ff119f commit a170ea8

4 files changed

Lines changed: 53 additions & 18 deletions

File tree

DESCRIPTION

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
Type: Package
22
Package: insight
33
Title: Easy Access to Model Information for Various Model Objects
4-
Version: 1.3.0.5
4+
Version: 1.3.0.6
55
Authors@R:
66
c(person(given = "Daniel",
77
family = "Lüdecke",

NEWS.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414

1515
* Better support for models of class `sdmTMB`.
1616

17+
* Improve efficiency of `clean_parameters()` for more complex *brms* models.
18+
1719
## Bug fixes
1820

1921
* Fixed issue in `clean_names()` for *brms* models with `mm()` in formula.

R/clean_parameters.R

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -502,6 +502,20 @@ clean_parameters.mlm <- function(x, ...) {
502502
}
503503
}
504504

505+
# multimembership / mixture?
506+
model_fam <- get_family(x)
507+
if (inherits(model_fam, "brmsfamily") && model_fam$family == "mixture") {
508+
class_params <- grepl("^b_mu\\d+_(.*)", out$Parameter)
509+
if (any(class_params)) {
510+
out$Group[class_params] <- paste(
511+
"Class",
512+
gsub("^b_mu(\\d+)_(.*)", "\\1", out$Parameter[class_params])
513+
)
514+
}
515+
} else {
516+
class_params <- NULL
517+
}
518+
505519
# retrieve auxiliary components
506520
dpars <- find_auxiliary(x)
507521

@@ -524,6 +538,15 @@ clean_parameters.mlm <- function(x, ...) {
524538
out$Cleaned_Parameter <- gsub(pattern = "^(b_|bs_|bsp_|bcs_)(?!zi_)(.*)", "\\2", out$Cleaned_Parameter, perl = TRUE)
525539
out$Cleaned_Parameter <- gsub(pattern = "^(b_zi_|bs_zi_|bsp_zi_|bcs_zi_)(.*)", "\\2", out$Cleaned_Parameter) # nolint
526540

541+
if (!is.null(class_params)) {
542+
# remove "mu<number>_" from parameters
543+
out$Cleaned_Parameter[class_params] <- gsub(
544+
"^mu(\\d+)_(.*)",
545+
"\\2",
546+
out$Cleaned_Parameter[class_params]
547+
)
548+
}
549+
527550
# correlation and sd
528551

529552
cor_sd <- grepl("(sd_|cor_)(.*)", out$Cleaned_Parameter)
@@ -620,7 +643,11 @@ clean_parameters.mlm <- function(x, ...) {
620643
out$Cleaned_Parameter[intercepts] <- "(Intercept)"
621644
}
622645

623-
interaction_terms <- grep(".", out$Cleaned_Parameter, fixed = TRUE)
646+
interaction_terms <- grep(
647+
".",
648+
out$Cleaned_Parameter[out$Effects != "random"],
649+
fixed = TRUE
650+
)
624651

625652
if (length(interaction_terms)) {
626653
for (i in interaction_terms) {

R/get_predicted_bayesian.R

Lines changed: 22 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,19 @@
11
# Bayesian --------------------------------------------------------------
22
# =======================================================================
33

4-
54
#' @rdname get_predicted
65
#' @export
7-
get_predicted.stanreg <- function(x,
8-
data = NULL,
9-
predict = "expectation",
10-
iterations = NULL,
11-
ci = NULL,
12-
ci_method = NULL,
13-
include_random = "default",
14-
include_smooth = TRUE,
15-
verbose = TRUE,
16-
...) {
6+
get_predicted.stanreg <- function(
7+
x,
8+
data = NULL,
9+
predict = "expectation",
10+
iterations = NULL,
11+
ci = NULL,
12+
ci_method = NULL,
13+
include_random = "default",
14+
include_smooth = TRUE,
15+
verbose = TRUE,
16+
...) {
1717
check_if_installed("rstantools")
1818

1919
if (is.null(ci_method)) {
@@ -51,7 +51,8 @@ get_predicted.stanreg <- function(x,
5151
# }
5252

5353
# prepare arguments, avoid possible matching by multiple actual arguments
54-
fun_args <- list(x,
54+
fun_args <- list(
55+
x,
5556
newdata = my_args$data,
5657
re.form = my_args$re.form,
5758
dpar = my_args$distributional_parameter,
@@ -82,8 +83,10 @@ get_predicted.stanreg <- function(x,
8283
model_family <- get_family(x)
8384
# exceptions
8485
is_wiener <- inherits(model_family, "brmsfamily") && model_family$family == "wiener"
85-
is_rtchoice <- model_family$family == "custom" && model_family$name == "lnr"
86-
is_mixture <- model_family$family == "mixture"
86+
is_rtchoice <- inherits(model_family, "brmsfamily") &&
87+
model_family$family == "custom" &&
88+
model_family$name == "lnr"
89+
is_mixture <- inherits(model_family, "brmsfamily") && model_family$family == "mixture"
8790

8891
# Special case for rwiener (get choice 1 as negative values)
8992
# Note that for mv models, x$family returns a list of families
@@ -101,7 +104,10 @@ get_predicted.stanreg <- function(x,
101104
}
102105

103106
# Handle special cases
104-
if (!my_args$predict %in% c("expectation", "response", "link") && inherits(model_family, "brmsfamily")) {
107+
if (
108+
!my_args$predict %in% c("expectation", "response", "link") &&
109+
inherits(model_family, "brmsfamily")
110+
) {
105111
if (is_wiener) {
106112
# Wiener (Drift Diffusion) Models --------------------
107113
# ----------------------------------------------------
@@ -146,7 +152,7 @@ get_predicted.stanreg <- function(x,
146152
# pp_mixture returns an array with probs, SE and intervals.
147153
# if requested, we extract the intervals here for the "ci_data"
148154
# data.frame
149-
res <- lapply (seq_len(nrow(mixture_output)), function(i) {
155+
res <- lapply(seq_len(nrow(mixture_output)), function(i) {
150156
max_prob <- which.max(mixture_output[i, 1, ])
151157
data.frame(
152158
Probability = mixture_output[i, 1, max_prob],

0 commit comments

Comments
 (0)