Skip to content

Commit 56d2e87

Browse files
group = model ~ term
1 parent ecf425f commit 56d2e87

File tree

5 files changed

+135
-60
lines changed

5 files changed

+135
-60
lines changed

R/modelsummary.R

+64-36
Original file line numberDiff line numberDiff line change
@@ -70,11 +70,15 @@ globalVariables(c('.', 'term', 'part', 'estimate', 'conf.high', 'conf.low',
7070
#' * list of lists, each of which includes 3 elements named "raw", "clean", "fmt". Unknown statistics are omitted. See the 'Examples' section below.
7171
#' @param gof_omit string regular expression. Omits all matching gof statistics from
7272
#' the table (using `grepl(perl=TRUE)`).
73-
#' @param group a two-sided formula with three components: "term", "model", and
74-
#' a parameter group identifier (e.g., outcome levels of a multinomial logit
75-
#' model). Example: `term+groupid~model` The group identifier must be the name
76-
#' of a column in the data.frame produced by `get_estimates(model)`. The
77-
#' "term" component must be on the left-hand side of the formula.
73+
#' @param group a two-sided formula with two or three components which describes
74+
#' how groups of parameters should be displayed. The formula must include both
75+
#' a "term" and a "model" component. In addition, a component can be used to
76+
#' identify groups of parameters (e.g., outcome levels of a multinomial logit
77+
#' model). This group identifier must be the name of a column in the
78+
#' data.frame produced by `get_estimates(model)`.
79+
#' * `term ~ model` displays coefficients as rows and models as columns
80+
#' * `model ~ term` displays models as rows and coefficients as columns
81+
#' * `response + term ~ model` displays response levels and coefficients as rows and models as columns.
7882
#' @param group_map named or unnamed character vector. Subset, rename, and
7983
#' reorder coefficient groups specified in the `group` argument. See `coef_map`.
8084
#' @param add_rows a data.frame (or tibble) with the same number of columns as
@@ -274,7 +278,7 @@ modelsummary <- function(
274278
coef_rename = NULL,
275279
gof_map = NULL,
276280
gof_omit = NULL,
277-
group = NULL,
281+
group = term ~ model,
278282
group_map = NULL,
279283
add_rows = NULL,
280284
align = NULL,
@@ -368,7 +372,6 @@ modelsummary <- function(
368372

369373
}
370374

371-
372375
term_order <- unique(unlist(lapply(est, function(x) x$term)))
373376
group_order <- unique(unlist(lapply(est, function(x) x$group)))
374377

@@ -386,34 +389,43 @@ modelsummary <- function(
386389
est[is.na(est)] <- ""
387390

388391
# sort rows using factor trick
389-
if (!is.null(coef_map)) {
390-
term_order <- coef_map
391-
est$term <- factor(est$term, unique(term_order))
392-
} else {
393-
est$term <- factor(est$term, unique(term_order))
394-
}
392+
if ("term" %in% colnames(est)) {
393+
if (!is.null(coef_map)) {
394+
term_order <- coef_map
395+
est$term <- factor(est$term, unique(term_order))
396+
} else {
397+
est$term <- factor(est$term, unique(term_order))
398+
}
395399

396-
if (!is.null(group_map)) {
397-
group_order <- group_map
398-
est$group <- factor(est$term, group_order)
399-
} else {
400-
est$group <- factor(est$group, unique(est$group))
400+
if (!is.null(group_map)) {
401+
group_order <- group_map
402+
est$group <- factor(est$term, group_order)
403+
} else {
404+
est$group <- factor(est$group, unique(est$group))
405+
}
406+
407+
} else if ("model" %in% colnames(est)) {
408+
est$model <- factor(est$model, model_names)
401409
}
402410

403411
est <- est[do.call(order, as.list(est)), ]
404412

405413
# character for binding
406-
est$term <- as.character(est$term)
407-
est$group <- as.character(est$group)
414+
for (col in c("term", "group", "model")) {
415+
if (col %in% colnames(est)) {
416+
est[[col]] <- as.character(est[[col]])
417+
}
418+
}
408419

409420
# group duplicates
410-
idx <- paste(as.character(est$term), est$statistic)
411-
if (is.null(group) && anyDuplicated(idx) > 0) {
412-
warning('The table includes duplicate term names. This can sometimes happen when a model produces "grouped" terms, such as in a multinomial logit or a gamlss model. Consider using the the `group` argument.')
421+
if ("term" %in% colnames(est)) {
422+
idx <- paste(as.character(est$term), est$statistic)
423+
if (is.null(group) && anyDuplicated(idx) > 0) {
424+
warning('The table includes duplicate term names. This can sometimes happen when a model produces "grouped" terms, such as in a multinomial logit or a gamlss model. Consider using the the `group` argument.')
425+
}
413426
}
414427

415428

416-
417429
#####################
418430
# goodness-of-fit #
419431
#####################
@@ -440,7 +452,6 @@ modelsummary <- function(
440452
}
441453

442454

443-
444455
##################
445456
# output table #
446457
##################
@@ -449,11 +460,11 @@ modelsummary <- function(
449460
tab[is.na(tab)] <- ''
450461

451462
# interaction : becomes ×
452-
if (is.null(coef_map)) {
453-
if (output_format != 'rtf') {
454-
idx <- tab$part != 'gof'
455-
tab$term <- ifelse(idx, gsub(':', ' \u00d7 ', tab$term), tab$term)
456-
}
463+
if (is.null(coef_map) &&
464+
"term" %in% colnames(tab) &&
465+
output_format != 'rtf') {
466+
idx <- tab$part != 'gof'
467+
tab$term <- ifelse(idx, gsub(':', ' \u00d7 ', tab$term), tab$term)
457468
}
458469

459470
# measure table
@@ -494,8 +505,8 @@ modelsummary <- function(
494505
}
495506

496507
# only show group label if it is a row-property (lhs of the group formula)
497-
if (is.null(group) ||
498-
group$group_name %in% group$rhs) {
508+
tmp <- setdiff(group$lhs, c("model", "term"))
509+
if (length(tmp) == 0) {
499510
tab$group <- NULL
500511
} else if (output_format != "dataframe") {
501512
colnames(tab)[colnames(tab) == "group"] <- " "
@@ -593,6 +604,7 @@ map_omit_gof <- function(gof, gof_omit, gof_map) {
593604

594605
# row identifier
595606
gof$part <- "gof"
607+
596608
gof <- gof[, unique(c("part", "term", names(gof)))]
597609

598610
# omit
@@ -634,12 +646,28 @@ map_omit_gof <- function(gof, gof_omit, gof_map) {
634646
#' @noRd
635647
group_reshape <- function(estimates, lhs, rhs, group_name) {
636648

637-
if (is.null(lhs)) return(estimates)
638-
639649
lhs[lhs == group_name] <- "group"
640650
rhs[rhs == group_name] <- "group"
641651

642-
if (all(c("term", "group") %in% lhs)) {
652+
# term ~ model (standard)
653+
if (is.null(lhs) ||
654+
(length(lhs) == 1 && lhs == "term" &&
655+
length(rhs) == 1 && rhs == "model")) {
656+
return(estimates)
657+
658+
# model ~ term
659+
} else if (length(lhs) == 1 && lhs == "model" &&
660+
length(rhs) == 1 && rhs == "term") {
661+
out <- tidyr::pivot_longer(estimates,
662+
cols = -c("group", "term", "statistic"),
663+
names_to = "model")
664+
out <- tidyr::pivot_wider(out, names_from = "term")
665+
666+
# order matters for sorting
667+
out <- out[, unique(c("group", "model", "statistic", colnames(out)))]
668+
669+
# term + group ~ model
670+
} else if (all(c("term", "group") %in% lhs)) {
643671
idx <- unique(c(lhs, colnames(estimates)))
644672
out <- estimates[, idx, drop = FALSE]
645673

@@ -660,7 +688,7 @@ group_reshape <- function(estimates, lhs, rhs, group_name) {
660688
} else if (all(c("group", "model") %in% rhs)) {
661689
out <- estimates
662690
out <- tidyr::pivot_longer(out,
663-
cols = !any_of(c("part", "group", "term", "statistic")),
691+
cols = !tidyselect::any_of(c("part", "group", "term", "statistic")),
664692
names_to = "model")
665693
out$idx_col <- paste(out[[rhs[1]]], "/", out[[rhs[2]]])
666694
out$model <- out$group <- NULL

R/sanitize_group.R

+36-12
Original file line numberDiff line numberDiff line change
@@ -3,21 +3,45 @@
33
#' @noRd
44
sanitize_group <- function(group) {
55

6-
if (is.null(group)) return(group)
6+
flag_error <- FALSE
77

8-
checkmate::assert_formula(group, null.ok = TRUE)
8+
checkmate::assert_formula(group)
99

10-
rhs <- all.vars(stats::update(group, "0 ~ ."))
11-
lhs <- all.vars(stats::update(group, ". ~ 0"))
10+
rhs <- all.vars(stats::update(group, "0 ~ ."))
11+
lhs <- all.vars(stats::update(group, ". ~ 0"))
12+
variables <- c(rhs, lhs)
13+
14+
if (!all(c("model", "term") %in% c(lhs, rhs))) {
15+
flag_error <- TRUE
16+
}
17+
18+
if (length(variables) != length(unique(variables))) {
19+
flag_error <- TRUE
20+
}
21+
22+
if (length(variables) > 3) {
23+
flag_error <- TRUE
24+
} else if (length(variables) == 2) {
25+
group_name <- NULL
26+
} else {
1227
group_name <- setdiff(c(lhs, rhs), c("term", "model"))
28+
}
29+
30+
if (flag_error == TRUE) {
31+
stop('The `group` argument must be a two-sided formula with two or three components. The formula must include a component named "term", which represents the parameters of the model. The formula must include a component named "model", which represents the different models being summarized. For example,
32+
33+
model ~ term
34+
35+
displays models as rows and parameter estimates as columns. Inverting the formula would display models as columns and terms as rows.
36+
37+
The formula can also include a third, optional, component: a group identifier. In contrast to the "term" and "model" components, the name of the group identifier is not fixed. It must correspond to the name of a column in the data.frame produced by `get_estimates(model)`. For example, applying the `get_estimates` function to a multinomial logit model returns a column called "response", which identifies the parameters that correspond to each value of the responde variable:
38+
39+
model + response ~ term')
40+
}
1341

14-
if (!all(c("term", "model") %in% c(lhs, rhs)) ||
15-
length(unique(c(lhs, rhs))) != 3) {
16-
stop('The `group` argument must be a two-sided formula with three components: "term", "model", and a group identifier. The group identifier must be the name of a column in the data.frame produced by `get_estimates(model)`. The "term" component must be on the left-hand side of the formula. ')
17-
}
42+
out <- list("lhs" = lhs,
43+
"rhs" = rhs,
44+
"group_name" = group_name)
1845

19-
out <- list("lhs" = lhs,
20-
"rhs" = rhs,
21-
"group_name" = group_name)
22-
return(out)
46+
return(out)
2347
}

man/modelsummary.Rd

+12-6
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/msummary.Rd

+12-6
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

tests/testthat/test-group.R

+11
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,17 @@ models[['Multivariate']] <- lm(Girth ~ Height + Volume, data = trees)
99
models[["GAMLSS"]] <- gamlss(y~pb(x),sigma.fo=~pb(x),family=BCT, data=abdom, method=mixed(1,20), trace=FALSE)
1010

1111

12+
test_that("flipped table (no groups)", {
13+
mod = list(
14+
lm(hp ~ mpg, mtcars),
15+
lm(hp ~ mpg + drat, mtcars))
16+
tab = modelsummary(mod,
17+
output = "data.frame",
18+
group = model ~ term)
19+
expect_true("model" %in% colnames(tab))
20+
})
21+
22+
1223
test_that("group: nnet::multinom", {
1324
skip_if_not_installed("nnet")
1425

0 commit comments

Comments
 (0)