Skip to content

Commit

Permalink
feat: add decision boundary plots
Browse files Browse the repository at this point in the history
  • Loading branch information
be-marc committed Dec 19, 2023
1 parent eafe2f4 commit 4adb798
Show file tree
Hide file tree
Showing 31 changed files with 3,324 additions and 803 deletions.
4 changes: 3 additions & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -63,14 +63,16 @@ Config/testthat/parallel: true
Encoding: UTF-8
NeedsCompilation: no
Roxygen: list(markdown = TRUE)
RoxygenNote: 7.2.3
RoxygenNote: 7.2.3.9000
Collate:
'BenchmarkResult.R'
'Filter.R'
'LearnerClassif.R'
'LearnerClassifCVGlmnet.R'
'LearnerClassifGlmnet.R'
'LearnerClassifRpart.R'
'LearnerClustHierarchical.R'
'LearnerRegr.R'
'LearnerRegrCVGlmnet.R'
'LearnerRegrGlmnet.R'
'LearnerRegrRpart.R'
Expand Down
4 changes: 4 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,15 @@ S3method(as_precrec,PredictionClassif)
S3method(as_precrec,ResampleResult)
S3method(autoplot,BenchmarkResult)
S3method(autoplot,Filter)
S3method(autoplot,LearnerClassif)
S3method(autoplot,LearnerClassifCVGlmnet)
S3method(autoplot,LearnerClassifGlmnet)
S3method(autoplot,LearnerClassifRpart)
S3method(autoplot,LearnerClustAgnes)
S3method(autoplot,LearnerClustDiana)
S3method(autoplot,LearnerClustHclust)
S3method(autoplot,LearnerClustHierarchical)
S3method(autoplot,LearnerRegr)
S3method(autoplot,LearnerRegrCVGlmnet)
S3method(autoplot,LearnerRegrGlmnet)
S3method(autoplot,LearnerRegrRpart)
Expand Down Expand Up @@ -52,6 +54,8 @@ import(checkmate)
import(data.table)
import(ggplot2)
import(mlr3misc)
importFrom(ggplot2,autoplot)
importFrom(ggplot2,fortify)
importFrom(graphics,plot)
importFrom(scales,pretty_breaks)
importFrom(stats,as.dendrogram)
Expand Down
87 changes: 87 additions & 0 deletions R/LearnerClassif.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
#' @title Plot for Classification Learners
#'
#' @description
#' Visualizations for [mlr3::LearnerClassif].
#' The argument `type` controls what kind of plot is drawn.
#' Possible choices are:
#'
#' * `"prediction"` (default): Decision boundary of the learner and the true class labels.
#'
#' @param object ([mlr3::LearnerClassif]).
#'
#' @template param_type
#' @template param_task
#' @template param_grid_points
#' @template param_expand_range
#' @template param_theme
#' @param ... (ignored).
#'
#' @return [ggplot2::ggplot()].
#'
#' @export
#' @examples
#' \donttest{
#' if (requireNamespace("mlr3")) {
#' library(mlr3)
#' library(mlr3viz)
#'
#' task = tsk("pima")$select(c("age", "pedigree"))
#' learner = lrn("classif.rpart", predict_type = "prob")
#' learner$train(task)
#'
#' autoplot(learner, type = "prediction", task)
#' }
#' }
autoplot.LearnerClassif = function(object, type = "prediction", task, grid_points = 100L, expand_range = 0, theme = theme_minimal(), ...) { # nolint
assert_string(type)

switch(type,
"prediction" = {
mlr3::assert_task(task)
features = task$feature_names

if (length(features) != 2L) {
mlr3misc::stopf("Plot learner prediction only works for tasks with two features for classification!", wrap = TRUE)
}

grid = predict_grid(list(object), task, grid_points = 100L, expand_range = 0)

if (object$predict_type == "prob") {
# classif, probs
raster_aes = aes(
fill = .data[["response"]],
alpha = .data[[".prob.response"]])
scale_alpha = scale_alpha_continuous(
name = "Probability",
guide = guide_legend(override.aes = list(fill = viridis::viridis(1))))
scale_fill = scale_fill_viridis_d(end = 0.8)
guides = NULL
} else if (object$predict_type == "response") {
# classif, no probs
raster_aes = aes(fill = .data[["response"]])
scale_alpha = NULL
scale_fill = scale_fill_viridis_d(end = 0.8)
guides = NULL
}

ggplot(grid,
mapping = aes(
x = .data[[features[1L]]],
y = .data[[features[2L]]])) +
geom_raster(raster_aes) +
geom_point(
mapping = aes(fill = .data[[task$target_names]]),
data = task$data(),
shape = 21,
color = "black") +
scale_fill +
guides +
theme +
theme(legend.position = "right") +
scale_alpha +
labs(fill = "Response")
},

stopf("Unknown plot type '%s'", type)
)
}
17 changes: 13 additions & 4 deletions R/LearnerClassifCVGlmnet.R
Original file line number Diff line number Diff line change
@@ -1,10 +1,19 @@
#' @rdname autoplot.LearnerClassifGlmnet
#' @export
autoplot.LearnerClassifCVGlmnet = function(object, theme = theme_minimal(), ...) { # nolint
plot_ggfortify(object, ...) +
scale_color_viridis_d("Feature") +
theme
autoplot.LearnerClassifCVGlmnet = function(object, type = "prediction", task = NULL, grid_points = 100L, expand_range = 0, theme = theme_minimal(), ...) { # nolint
switch(type,
"prediction" = {
NextMethod()
},

"ggfortify" = {
plot_ggfortify(object, ...) +
scale_color_viridis_d("Feature") +
theme
},

stopf("Unknown plot type '%s'", type)
)
}

#' @export
Expand Down
42 changes: 32 additions & 10 deletions R/LearnerClassifGlmnet.R
Original file line number Diff line number Diff line change
@@ -1,9 +1,19 @@
#' @title Plots for GLMNet Learners
#'
#' @description
#' Visualizations for GLMNet learners using the package \CRANpkg{ggfortify}.
#' Visualizations for [mlr3learners::LearnerClassifGlmnet].
#' The argument `type` controls what kind of plot is drawn.
#' Possible choices are:
#'
#' * `"prediction"` (default): Decision boundary of the learner and the true class labels.
#' * `"ggfortify"`: Visualizes the model using the package \CRANpkg{ggfortify}.
#'
#' @param object ([mlr3learners::LearnerClassifGlmnet] | [mlr3learners::LearnerRegrGlmnet] | [mlr3learners::LearnerRegrCVGlmnet] | [mlr3learners::LearnerRegrCVGlmnet]).
#'
#' @template param_type
#' @template param_task
#' @template param_grid_points
#' @template param_expand_range
#' @template param_theme
#' @param ... (ignored).
#'
Expand All @@ -23,22 +33,34 @@
#' task = tsk("sonar")
#' learner = lrn("classif.glmnet")
#' learner$train(task)
#' autoplot(learner)
#' autoplot(learner, type = "ggfortify")
#'
#' # regression
#' task = tsk("mtcars")
#' learner = lrn("regr.glmnet")
#' learner$train(task)
#' autoplot(learner)
#' autoplot(learner, type = "ggfortify")
#' }
autoplot.LearnerClassifGlmnet = function(object, theme = theme_minimal(), ...) { # nolint
if ("twoclass" %nin% object$state$train_task$properties) {
stopf("Plot of %s only works with binary classification tasks.", object$id)
}
autoplot.LearnerClassifGlmnet = function(object, type = "prediction", task = NULL, grid_points = 100L, expand_range = 0, theme = theme_minimal(), ...) { # nolint
assert_has_model(object)

switch(type,
"prediction" = {
NextMethod()
},

"ggfortify" = {
if ("twoclass" %nin% object$state$train_task$properties) {
stopf("Plot of %s only works with binary classification tasks.", object$id)
}

plot_ggfortify(object) +
scale_color_viridis_d("Feature") +
theme
},

plot_ggfortify(object) +
scale_color_viridis_d("Feature") +
theme
stopf("Unknown plot type '%s'", type)
)
}

#' @export
Expand Down
75 changes: 45 additions & 30 deletions R/LearnerClassifRpart.R
Original file line number Diff line number Diff line change
@@ -1,9 +1,19 @@
#' @title Plots for Rpart Learners
#'
#' @description
#' Visualizations rpart trees using the package \CRANpkg{ggparty}.
#' Visualizations for [mlr3::LearnerClassifRpart].
#' The argument `type` controls what kind of plot is drawn.
#' Possible choices are:
#'
#' * `"prediction"` (default): Decision boundary of the learner and the true class labels.
#' * `"ggparty"`: Visualizes the tree using the package \CRANpkg{ggparty}.
#'
#' @param object ([mlr3::LearnerClassifRpart] | [mlr3::LearnerRegrRpart]).
#'
#' @template param_type
#' @template param_task
#' @template param_grid_points
#' @template param_expand_range
#' @template param_theme
#' @param ... (ignored).
#'
Expand All @@ -19,45 +29,50 @@
#' task = tsk("iris")
#' learner = lrn("classif.rpart", keep_model = TRUE)
#' learner$train(task)
#' autoplot(learner)
#' autoplot(learner, type = "ggparty")
#'
#' # regression
#' task = tsk("mtcars")
#' learner = lrn("regr.rpart", keep_model = TRUE)
#' learner$train(task)
#' autoplot(learner)
#' autoplot(learner, type = "ggparty")
#' }
autoplot.LearnerClassifRpart = function(object, theme = theme_minimal(), ...) { # nolint
autoplot.LearnerClassifRpart = function(object, type = "prediction", task = NULL, grid_points = 100L, expand_range = 0, theme = theme_minimal(), ...) { # nolint
assert_has_model(object)

if (is.null(object$model$model)) {
stopf("Learner '%s' must be trained with `keep_model` set to `TRUE`", object$id)
}
switch(type,
"prediction" = {
NextMethod()
},

"ggparty" = {
require_namespaces(c("partykit", "ggparty"))
target = all.vars(object$model$terms)[1L]

require_namespaces(c("partykit", "ggparty"))
target = all.vars(object$model$terms)[1L]
ggparty::ggparty(partykit::as.party(object$model)) +
ggparty::geom_edge() +
ggparty::geom_edge_label() +
ggparty::geom_node_splitvar() +
ggparty::geom_node_plot(
gglist = list(
geom_bar(aes(x = "", fill = .data[[target]]),
alpha = 0.8,
color = "#000000",
linewidth = 0.5,
position = position_fill()),
xlab(target),
scale_fill_viridis_d(end = 0.8),
theme),
ids = "terminal",
shared_axis_labels= TRUE) +
ggparty::geom_node_label(
mapping = aes(label = paste0("n=", .data[["nodesize"]])),
nudge_y = 0.03,
ids = "terminal")
},

ggparty::ggparty(partykit::as.party(object$model)) +
ggparty::geom_edge() +
ggparty::geom_edge_label() +
ggparty::geom_node_splitvar() +
ggparty::geom_node_plot(
gglist = list(
geom_bar(aes(x = "", fill = .data[[target]]),
alpha = 0.8,
color = "#000000",
linewidth = 0.5,
position = position_fill()),
xlab(target),
scale_fill_viridis_d(end = 0.8),
theme),
ids = "terminal",
shared_axis_labels= TRUE) +
ggparty::geom_node_label(
mapping = aes(label = paste0("n=", .data[["nodesize"]])),
nudge_y = 0.03,
ids = "terminal"
)
stopf("Unknown plot type '%s'", type)
)
}

#' @export
Expand Down
Loading

0 comments on commit 4adb798

Please sign in to comment.