Skip to content

Commit 56c8c54

Browse files
committed
support classification for mixture models
1 parent 2f2bda2 commit 56c8c54

4 files changed

Lines changed: 45 additions & 14 deletions

File tree

NEWS.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,11 @@
77

88
* `get_predicted()` now supports models of class `glmtoolbox::glmee`.
99

10+
* `get_predicted()` supports predicting the class membership for models from
11+
package *brms* with `mixture()` family, using `predict = "classificaton"`.
12+
13+
* `model_info()` returns `$is_mixture` to identify finite mixture models.
14+
1015
* Better support for models of class `sdmTMB`.
1116

1217
## Bug fixes

R/get_predicted.R

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -35,9 +35,16 @@
3535
#' * `"prediction"` also gives an output on the response scale, but this time
3636
#' associated with a prediction interval (PI), which is larger than a confidence
3737
#' interval (though it mostly make sense for linear models).
38-
#' * `"classification"` only differs from `"prediction"` for binomial models
39-
#' where it additionally transforms the predictions into the original response's
40-
#' type (for instance, to a factor).
38+
#' * `"classification"` is releveant only for binomial, ordinal or mixture models.
39+
#' - For binomial models, `predict = "classification"` will additionally
40+
#' transform the predictions into the original response's type (for
41+
#' instance, to a factor).
42+
#' - For ordinal models (e.g., classes `clm` or `multinom`), gives the
43+
#' predicted response class membership, defined as highest probability
44+
#' prediction.
45+
#' - For finite mixture models (currently only family [`brms::mixture()`] from
46+
#' package *brms*), returns a vector of predicted class membership (similar
47+
#' as for ordinal models).
4148
#' * Other strings are passed directly to the `type` argument of the `predict()`
4249
#' method supplied by the modelling package.
4350
#' * Specifically for models of class `brmsfit` (package *brms*), the `predict`
@@ -49,7 +56,7 @@
4956
#' by the modelling package. Note that this might result in conflicts with
5057
#' multiple matching `type` arguments - thus, the recommendation is to use the
5158
#' `predict` argument for those values.
52-
#' * Notes: You can see the 4 options for predictions as on a gradient from
59+
#' * Notes: You can see the four options for predictions as on a gradient from
5360
#' "close to the model" to "close to the response data": "link", "expectation",
5461
#' "prediction", "classification". The `predict` argument modulates two things:
5562
#' the scale of the output and the type of certainty interval. Read more about
@@ -138,10 +145,12 @@
138145
#' and no transformation is applied. For instance, for a logistic regression
139146
#' model, the response scale corresponds to the predicted probabilities, whereas
140147
#' the link-scale makes predictions of log-odds (probabilities on the logit
141-
#' scale). Note that when users select `predict="classification"` in binomial
148+
#' scale). Note that when users select `predict = "classification"` in binomial
142149
#' models, the `get_predicted()` function will first calculate predictions as if
143-
#' the user had selected `predict="expectation"`. Then, it will round the
144-
#' responses in order to return the most likely outcome.
150+
#' the user had selected `predict = "expectation"`. Then, it will round the
151+
#' responses in order to return the most likely outcome. For ordinal or mixture
152+
#' models, it returns the predicted class membership, based on the highest
153+
#' probability of classification.
145154
#'
146155
#' @section Heteroscedasticity consistent standard errors: The arguments `vcov`
147156
#' and `vcov_args` can be used to calculate robust standard errors for

R/get_predicted_bayesian.R

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,7 @@ get_predicted.stanreg <- function(x,
8383
# exceptions
8484
is_wiener <- inherits(model_family, "brmsfamily") && model_family$family == "wiener"
8585
is_rtchoice <- model_family$family == "custom" && model_family$name == "lnr"
86+
is_mixture <- model_family$family == "mixture"
8687

8788
# Special case for rwiener (get choice 1 as negative values)
8889
# Note that for mv models, x$family returns a list of families
@@ -119,6 +120,11 @@ get_predicted.stanreg <- function(x,
119120
dim = c(dim(draws), 2),
120121
dimnames = list(NULL, NULL, c("rt", "response"))
121122
)
123+
} else if (is_mixture && identical(my_args$predict, "classification")) {
124+
# for mixture models, which predict the class membership, we stop
125+
# here and just return the predicted class membership
126+
mixture_output <- do.call(brms::pp_mixture, fun_args)
127+
return(apply(mixture_output[, 1, ], 1, which.max))
122128
}
123129
}
124130

man/get_predicted.Rd

Lines changed: 18 additions & 7 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)