Skip to content

Commit ed61b84

Browse files
authored
Better support for models of class sdmTMB (#1075)
* Better support for models of class `sdmTMB` * supported list * supported model list * readme * more support * fix * fix * find_predictors * find_predictors * model info * more * disable old tests * get_predicted for sdmTMB * docs * remove unused code * add to suggest * styler * fix * fix * new range option * add getvarcov * tests * fix * remove commented code * skip on ubuntu
1 parent 785c6c6 commit ed61b84

26 files changed

Lines changed: 838 additions & 218 deletions

DESCRIPTION

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,7 @@ Suggests:
195195
rstudioapi,
196196
RWiener,
197197
sandwich,
198+
sdmTMB,
198199
serp,
199200
speedglm,
200201
splines,

NAMESPACE

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,7 @@ S3method(find_formula,poissonmfx)
178178
S3method(find_formula,probitmfx)
179179
S3method(find_formula,rlmerMod)
180180
S3method(find_formula,rma)
181+
S3method(find_formula,sdmTMB)
181182
S3method(find_formula,selection)
182183
S3method(find_formula,sem)
183184
S3method(find_formula,stanmvreg)
@@ -304,6 +305,7 @@ S3method(find_parameters,rms)
304305
S3method(find_parameters,rqs)
305306
S3method(find_parameters,rqss)
306307
S3method(find_parameters,scam)
308+
S3method(find_parameters,sdmTMB)
307309
S3method(find_parameters,selection)
308310
S3method(find_parameters,sem)
309311
S3method(find_parameters,sim)
@@ -329,6 +331,7 @@ S3method(find_predictors,default)
329331
S3method(find_predictors,fixest)
330332
S3method(find_predictors,insight_formula)
331333
S3method(find_predictors,logitr)
334+
S3method(find_predictors,sdmTMB)
332335
S3method(find_predictors,selection)
333336
S3method(find_random,afex_aov)
334337
S3method(find_random,default)
@@ -577,6 +580,7 @@ S3method(get_dispersion,model_fit)
577580
S3method(get_family,default)
578581
S3method(get_family,list)
579582
S3method(get_family,model_fit)
583+
S3method(get_family,sdmTMB)
580584
S3method(get_intercept,default)
581585
S3method(get_intercept,stanreg)
582586
S3method(get_loglikelihood,afex_aov)
@@ -743,6 +747,7 @@ S3method(get_parameters,rms)
743747
S3method(get_parameters,rqs)
744748
S3method(get_parameters,rqss)
745749
S3method(get_parameters,scam)
750+
S3method(get_parameters,sdmTMB)
746751
S3method(get_parameters,selection)
747752
S3method(get_parameters,sem)
748753
S3method(get_parameters,sim)
@@ -794,6 +799,7 @@ S3method(get_predicted,prcomp)
794799
S3method(get_predicted,principal)
795800
S3method(get_predicted,rlm)
796801
S3method(get_predicted,rma)
802+
S3method(get_predicted,sdmTMB)
797803
S3method(get_predicted,stanreg)
798804
S3method(get_predicted,survreg)
799805
S3method(get_predicted,zeroinfl)
@@ -963,6 +969,7 @@ S3method(get_statistic,rq)
963969
S3method(get_statistic,rqs)
964970
S3method(get_statistic,rqss)
965971
S3method(get_statistic,scam)
972+
S3method(get_statistic,sdmTMB)
966973
S3method(get_statistic,selection)
967974
S3method(get_statistic,sem)
968975
S3method(get_statistic,summary.lm)
@@ -1055,6 +1062,7 @@ S3method(get_varcov,probitmfx)
10551062
S3method(get_varcov,robmixglm)
10561063
S3method(get_varcov,rq)
10571064
S3method(get_varcov,rqs)
1065+
S3method(get_varcov,sdmTMB)
10581066
S3method(get_varcov,selection)
10591067
S3method(get_varcov,tobit)
10601068
S3method(get_varcov,truncreg)
@@ -1205,6 +1213,7 @@ S3method(link_function,robmixglm)
12051213
S3method(link_function,rq)
12061214
S3method(link_function,rqs)
12071215
S3method(link_function,rqss)
1216+
S3method(link_function,sdmTMB)
12081217
S3method(link_function,serp)
12091218
S3method(link_function,speedglm)
12101219
S3method(link_function,speedlm)
@@ -1336,6 +1345,7 @@ S3method(link_inverse,robmixglm)
13361345
S3method(link_inverse,rq)
13371346
S3method(link_inverse,rqs)
13381347
S3method(link_inverse,rqss)
1348+
S3method(link_inverse,sdmTMB)
13391349
S3method(link_inverse,serp)
13401350
S3method(link_inverse,speedglm)
13411351
S3method(link_inverse,speedlm)
@@ -1492,6 +1502,7 @@ S3method(model_info,robmixglm)
14921502
S3method(model_info,rq)
14931503
S3method(model_info,rqs)
14941504
S3method(model_info,rqss)
1505+
S3method(model_info,sdmTMB)
14951506
S3method(model_info,serp)
14961507
S3method(model_info,speedglm)
14971508
S3method(model_info,speedlm)

NEWS.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,12 @@
11
# insight (devel)
22

3+
## Changes
4+
5+
* The `range` argument in `get_datagrid()` gets a new option, `"pretty"`, to
6+
create a range of pretty values.
7+
8+
* Better support for models of class `sdmTMB`.
9+
310
## Bug fixes
411

512
* Fixed issue in `clean_names()` for *brms* models with `mm()` in formula.

R/compute_variances.R

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,6 @@
163163
if (component %in% c("residual", "distribution", "all")) {
164164
var.distribution <- .compute_variance_distribution(
165165
model,
166-
var_cor = mixed_effects_info$vc,
167166
faminfo,
168167
model_null = model_null,
169168
revar_null = var.random_null,
@@ -586,7 +585,6 @@
586585
# different values for the log/delta/trigamma approximation.
587586
# -----------------------------------------------------------------------------
588587
.compute_variance_distribution <- function(model,
589-
var_cor,
590588
faminfo,
591589
model_null = NULL,
592590
revar_null = NULL,

R/find_formula.R

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1296,6 +1296,36 @@ find_formula.glmmTMB <- function(x, verbose = TRUE, ...) {
12961296
}
12971297

12981298

1299+
#' @export
1300+
find_formula.sdmTMB <- function(x, verbose = TRUE, ...) {
1301+
f.cond <- stats::formula(x)
1302+
1303+
# sanity check, we might have a list of formulas, only need first element
1304+
if (is.list(f.cond)) {
1305+
f.cond <- f.cond[[1]]
1306+
}
1307+
1308+
# extract random parts of formula
1309+
f.random <- lapply(.findbars(f.cond), function(.x) {
1310+
f <- safe_deparse(.x)
1311+
stats::as.formula(paste0("~", f))
1312+
})
1313+
1314+
if (length(f.random) == 1L) {
1315+
f.random <- f.random[[1]]
1316+
}
1317+
1318+
# extract fixed effects parts
1319+
f.cond <- stats::as.formula(.get_fixed_effects(f.cond))
1320+
1321+
f <- compact_list(list(
1322+
conditional = f.cond,
1323+
random = f.random
1324+
))
1325+
.find_formula_return(f, verbose = verbose)
1326+
}
1327+
1328+
12991329
#' @export
13001330
find_formula.nlmerMod <- function(x, verbose = TRUE, ...) {
13011331
f.random <- lapply(.findbars(stats::formula(x)), function(.x) {

R/find_parameters_other.R

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,39 @@ find_parameters.glmgee <- function(x, component = "all", flatten = FALSE, ...) {
6363
}
6464

6565

66+
#' @export
67+
find_parameters.sdmTMB <- function(x, component = "all", flatten = FALSE, ...) {
68+
delta_comp <- isTRUE(x$family$delta)
69+
valid_comp <- compact_character(c("all", "conditional", ifelse(delta_comp, "delta", "")))
70+
component <- validate_argument(component, valid_comp)
71+
72+
cf <- suppressMessages(stats::coef(x, model = 1))
73+
conditional <- names(cf)
74+
75+
if (delta_comp) {
76+
cf <- suppressMessages(stats::coef(x, model = 2))
77+
delta <- names(cf)
78+
}
79+
80+
if (delta_comp) {
81+
out <- list(
82+
conditional = text_remove_backticks(conditional),
83+
delta = text_remove_backticks(delta)
84+
)
85+
} else {
86+
out <- list(conditional = text_remove_backticks(conditional))
87+
}
88+
89+
.filter_parameters(
90+
out,
91+
effects = "all",
92+
component = component,
93+
flatten = flatten,
94+
recursive = FALSE
95+
)
96+
}
97+
98+
6699
#' @export
67100
find_parameters.betareg <- function(x, component = "all", flatten = FALSE, ...) {
68101
component <- validate_argument(

R/find_predictors.R

Lines changed: 40 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -187,7 +187,7 @@ find_predictors.default <- function(x,
187187
if (flatten) {
188188
unique(unlist(l, use.names = FALSE))
189189
} else {
190-
l
190+
compact_list(l)
191191
}
192192
}
193193

@@ -295,7 +295,7 @@ find_predictors.afex_aov <- function(x,
295295
if (flatten) {
296296
unique(unlist(l, use.names = FALSE))
297297
} else {
298-
l
298+
compact_list(l)
299299
}
300300
}
301301

@@ -364,6 +364,44 @@ find_predictors.brmsfit <- function(x,
364364
}
365365

366366

367+
#' @export
368+
find_predictors.sdmTMB <- function(x,
369+
effects = "fixed",
370+
flatten = FALSE,
371+
verbose = TRUE,
372+
...) {
373+
effects <- validate_argument(effects, c("fixed", "random", "all"))
374+
elements <- .get_elements(effects, component = "conditional", model = x)
375+
376+
f <- find_formula(x, verbose = verbose)
377+
f <- .prepare_predictors(x, f, elements)
378+
379+
# random effects are returned as list, so we need to unlist here
380+
l <- .return_vars(f, x)
381+
382+
if (is_empty_object(l) || is_empty_object(compact_list(l))) {
383+
return(NULL)
384+
}
385+
386+
# add time variable
387+
l$time <- x$call$time
388+
389+
# add random slope, if not yet present
390+
if (object_has_names(l, "random") && effects == "all") {
391+
random_slope <- unlist(find_random_slopes(x), use.names = FALSE)
392+
all_predictors <- unlist(unique(l), use.names = FALSE)
393+
rs_not_in_pred <- unique(setdiff(random_slope, all_predictors))
394+
if (length(rs_not_in_pred)) l$random <- c(rs_not_in_pred, l$random)
395+
}
396+
397+
if (flatten) {
398+
unique(unlist(l, use.names = FALSE))
399+
} else {
400+
compact_list(l)
401+
}
402+
}
403+
404+
367405
#' @export
368406
find_predictors.insight_formula <- function(x, flatten = FALSE, verbose = TRUE, ...) {
369407
is_mv <- is_multivariate(x)

R/find_statistic.R

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,7 @@ find_statistic.default <- function(x, ...) {
140140
"qr", "QRNLMM", "QRLMM",
141141
"Rchoice", "riskRegression", "robmixglm", "rma", "rma.mv", "rma.uni", "rrvglm",
142142
"Sarlm", "sem", "SemiParBIV", "serp", "slm", "slopes", "survreg", "svy_vglm",
143+
"sdmTMB",
143144
"test_mediation", "tobit",
144145
"vglm",
145146
"wbgee",

R/get_datagrid.R

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,8 @@
124124
#' length as indicated in `length` are generated. For numeric predictors not
125125
#' specified at first in `by`, mean and -1/+1 SD around the mean are
126126
#' returned. For factors, all levels are returned.
127+
#' - `"pretty"` will create a range "pretty" values, using [`pretty()`], where
128+
#' the value in `length` is used for the `n` argument in `pretty()`.
127129
#'
128130
#' `range` can also be a vector of different values (see 'Examples'). In this
129131
#' case, `range` must be of same length as numeric target variables. If
@@ -1181,7 +1183,7 @@ get_datagrid.comparisons <- get_datagrid.slopes
11811183
...) {
11821184
range <- validate_argument(
11831185
tolower(range),
1184-
c("range", "iqr", "ci", "hdi", "eti", "sd", "mad", "grid")
1186+
c("range", "iqr", "ci", "hdi", "eti", "sd", "mad", "grid", "pretty")
11851187
)
11861188

11871189
# bayestestR only for some options
@@ -1239,6 +1241,11 @@ get_datagrid.comparisons <- get_datagrid.slopes
12391241
return(out)
12401242
}
12411243

1244+
# If Range is a range of pretty values
1245+
if (range == "pretty") {
1246+
return(pretty(x, n = length))
1247+
}
1248+
12421249
# If Range is an interval
12431250
if (range == "iqr") { # nolint
12441251
mini <- stats::quantile(x, (1 - ci) / 2, ...)

R/get_family.R

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,26 @@ get_family.model_fit <- function(x, ...) {
4747
get_family(x$fit, ...)
4848
}
4949

50+
#' @export
51+
get_family.sdmTMB <- function(x, ...) {
52+
check_if_installed("sdmTMB")
53+
f <- x$family
54+
if (length(f$family) > 1) {
55+
f <- compact_list(list(
56+
family = f$family[2],
57+
link = f$link[2],
58+
linkfun = f[[2]]$linkfun,
59+
linkinv = f[[2]]$linkinv,
60+
mu.eta = f[[2]]$mu.eta,
61+
valideta = f[[2]]$valideta,
62+
name = f[[2]]$name,
63+
initialize = f[[2]]$initialize
64+
))
65+
class(f) <- "family"
66+
}
67+
f
68+
}
69+
5070

5171
.get_family <- function(x, ...) {
5272
info <- model_info(x, response = 1, verbose = FALSE)

0 commit comments

Comments
 (0)