Skip to content

Commit 4adb798

Browse files
committed
feat: add decision boundary plots
1 parent eafe2f4 commit 4adb798

31 files changed

+3324
-803
lines changed

DESCRIPTION

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,14 +63,16 @@ Config/testthat/parallel: true
6363
Encoding: UTF-8
6464
NeedsCompilation: no
6565
Roxygen: list(markdown = TRUE)
66-
RoxygenNote: 7.2.3
66+
RoxygenNote: 7.2.3.9000
6767
Collate:
6868
'BenchmarkResult.R'
6969
'Filter.R'
70+
'LearnerClassif.R'
7071
'LearnerClassifCVGlmnet.R'
7172
'LearnerClassifGlmnet.R'
7273
'LearnerClassifRpart.R'
7374
'LearnerClustHierarchical.R'
75+
'LearnerRegr.R'
7476
'LearnerRegrCVGlmnet.R'
7577
'LearnerRegrGlmnet.R'
7678
'LearnerRegrRpart.R'

NAMESPACE

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,15 @@ S3method(as_precrec,PredictionClassif)
55
S3method(as_precrec,ResampleResult)
66
S3method(autoplot,BenchmarkResult)
77
S3method(autoplot,Filter)
8+
S3method(autoplot,LearnerClassif)
89
S3method(autoplot,LearnerClassifCVGlmnet)
910
S3method(autoplot,LearnerClassifGlmnet)
1011
S3method(autoplot,LearnerClassifRpart)
1112
S3method(autoplot,LearnerClustAgnes)
1213
S3method(autoplot,LearnerClustDiana)
1314
S3method(autoplot,LearnerClustHclust)
1415
S3method(autoplot,LearnerClustHierarchical)
16+
S3method(autoplot,LearnerRegr)
1517
S3method(autoplot,LearnerRegrCVGlmnet)
1618
S3method(autoplot,LearnerRegrGlmnet)
1719
S3method(autoplot,LearnerRegrRpart)
@@ -52,6 +54,8 @@ import(checkmate)
5254
import(data.table)
5355
import(ggplot2)
5456
import(mlr3misc)
57+
importFrom(ggplot2,autoplot)
58+
importFrom(ggplot2,fortify)
5559
importFrom(graphics,plot)
5660
importFrom(scales,pretty_breaks)
5761
importFrom(stats,as.dendrogram)

R/LearnerClassif.R

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
#' @title Plot for Classification Learners
2+
#'
3+
#' @description
4+
#' Visualizations for [mlr3::LearnerClassif].
5+
#' The argument `type` controls what kind of plot is drawn.
6+
#' Possible choices are:
7+
#'
8+
#' * `"prediction"` (default): Decision boundary of the learner and the true class labels.
9+
#'
10+
#' @param object ([mlr3::LearnerClassif]).
11+
#'
12+
#' @template param_type
13+
#' @template param_task
14+
#' @template param_grid_points
15+
#' @template param_expand_range
16+
#' @template param_theme
17+
#' @param ... (ignored).
18+
#'
19+
#' @return [ggplot2::ggplot()].
20+
#'
21+
#' @export
22+
#' @examples
23+
#' \donttest{
24+
#' if (requireNamespace("mlr3")) {
25+
#' library(mlr3)
26+
#' library(mlr3viz)
27+
#'
28+
#' task = tsk("pima")$select(c("age", "pedigree"))
29+
#' learner = lrn("classif.rpart", predict_type = "prob")
30+
#' learner$train(task)
31+
#'
32+
#' autoplot(learner, type = "prediction", task)
33+
#' }
34+
#' }
35+
autoplot.LearnerClassif = function(object, type = "prediction", task, grid_points = 100L, expand_range = 0, theme = theme_minimal(), ...) { # nolint
36+
assert_string(type)
37+
38+
switch(type,
39+
"prediction" = {
40+
mlr3::assert_task(task)
41+
features = task$feature_names
42+
43+
if (length(features) != 2L) {
44+
mlr3misc::stopf("Plot learner prediction only works for tasks with two features for classification!", wrap = TRUE)
45+
}
46+
47+
grid = predict_grid(list(object), task, grid_points = 100L, expand_range = 0)
48+
49+
if (object$predict_type == "prob") {
50+
# classif, probs
51+
raster_aes = aes(
52+
fill = .data[["response"]],
53+
alpha = .data[[".prob.response"]])
54+
scale_alpha = scale_alpha_continuous(
55+
name = "Probability",
56+
guide = guide_legend(override.aes = list(fill = viridis::viridis(1))))
57+
scale_fill = scale_fill_viridis_d(end = 0.8)
58+
guides = NULL
59+
} else if (object$predict_type == "response") {
60+
# classif, no probs
61+
raster_aes = aes(fill = .data[["response"]])
62+
scale_alpha = NULL
63+
scale_fill = scale_fill_viridis_d(end = 0.8)
64+
guides = NULL
65+
}
66+
67+
ggplot(grid,
68+
mapping = aes(
69+
x = .data[[features[1L]]],
70+
y = .data[[features[2L]]])) +
71+
geom_raster(raster_aes) +
72+
geom_point(
73+
mapping = aes(fill = .data[[task$target_names]]),
74+
data = task$data(),
75+
shape = 21,
76+
color = "black") +
77+
scale_fill +
78+
guides +
79+
theme +
80+
theme(legend.position = "right") +
81+
scale_alpha +
82+
labs(fill = "Response")
83+
},
84+
85+
stopf("Unknown plot type '%s'", type)
86+
)
87+
}

R/LearnerClassifCVGlmnet.R

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,19 @@
11
#' @rdname autoplot.LearnerClassifGlmnet
22
#' @export
3-
autoplot.LearnerClassifCVGlmnet = function(object, theme = theme_minimal(), ...) { # nolint
4-
plot_ggfortify(object, ...) +
5-
scale_color_viridis_d("Feature") +
6-
theme
3+
autoplot.LearnerClassifCVGlmnet = function(object, type = "prediction", task = NULL, grid_points = 100L, expand_range = 0, theme = theme_minimal(), ...) { # nolint
4+
switch(type,
5+
"prediction" = {
6+
NextMethod()
7+
},
78

9+
"ggfortify" = {
10+
plot_ggfortify(object, ...) +
11+
scale_color_viridis_d("Feature") +
12+
theme
13+
},
14+
15+
stopf("Unknown plot type '%s'", type)
16+
)
817
}
918

1019
#' @export

R/LearnerClassifGlmnet.R

Lines changed: 32 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,19 @@
11
#' @title Plots for GLMNet Learners
22
#'
33
#' @description
4-
#' Visualizations for GLMNet learners using the package \CRANpkg{ggfortify}.
4+
#' Visualizations for [mlr3learners::LearnerClassifGlmnet].
5+
#' The argument `type` controls what kind of plot is drawn.
6+
#' Possible choices are:
7+
#'
8+
#' * `"prediction"` (default): Decision boundary of the learner and the true class labels.
9+
#' * `"ggfortify"`: Visualizes the model using the package \CRANpkg{ggfortify}.
510
#'
611
#' @param object ([mlr3learners::LearnerClassifGlmnet] | [mlr3learners::LearnerRegrGlmnet] | [mlr3learners::LearnerRegrCVGlmnet] | [mlr3learners::LearnerRegrCVGlmnet]).
12+
#'
13+
#' @template param_type
14+
#' @template param_task
15+
#' @template param_grid_points
16+
#' @template param_expand_range
717
#' @template param_theme
818
#' @param ... (ignored).
919
#'
@@ -23,22 +33,34 @@
2333
#' task = tsk("sonar")
2434
#' learner = lrn("classif.glmnet")
2535
#' learner$train(task)
26-
#' autoplot(learner)
36+
#' autoplot(learner, type = "ggfortify")
2737
#'
2838
#' # regression
2939
#' task = tsk("mtcars")
3040
#' learner = lrn("regr.glmnet")
3141
#' learner$train(task)
32-
#' autoplot(learner)
42+
#' autoplot(learner, type = "ggfortify")
3343
#' }
34-
autoplot.LearnerClassifGlmnet = function(object, theme = theme_minimal(), ...) { # nolint
35-
if ("twoclass" %nin% object$state$train_task$properties) {
36-
stopf("Plot of %s only works with binary classification tasks.", object$id)
37-
}
44+
autoplot.LearnerClassifGlmnet = function(object, type = "prediction", task = NULL, grid_points = 100L, expand_range = 0, theme = theme_minimal(), ...) { # nolint
45+
assert_has_model(object)
46+
47+
switch(type,
48+
"prediction" = {
49+
NextMethod()
50+
},
51+
52+
"ggfortify" = {
53+
if ("twoclass" %nin% object$state$train_task$properties) {
54+
stopf("Plot of %s only works with binary classification tasks.", object$id)
55+
}
56+
57+
plot_ggfortify(object) +
58+
scale_color_viridis_d("Feature") +
59+
theme
60+
},
3861

39-
plot_ggfortify(object) +
40-
scale_color_viridis_d("Feature") +
41-
theme
62+
stopf("Unknown plot type '%s'", type)
63+
)
4264
}
4365

4466
#' @export

R/LearnerClassifRpart.R

Lines changed: 45 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,19 @@
11
#' @title Plots for Rpart Learners
22
#'
33
#' @description
4-
#' Visualizations rpart trees using the package \CRANpkg{ggparty}.
4+
#' Visualizations for [mlr3::LearnerClassifRpart].
5+
#' The argument `type` controls what kind of plot is drawn.
6+
#' Possible choices are:
7+
#'
8+
#' * `"prediction"` (default): Decision boundary of the learner and the true class labels.
9+
#' * `"ggparty"`: Visualizes the tree using the package \CRANpkg{ggparty}.
510
#'
611
#' @param object ([mlr3::LearnerClassifRpart] | [mlr3::LearnerRegrRpart]).
12+
#'
13+
#' @template param_type
14+
#' @template param_task
15+
#' @template param_grid_points
16+
#' @template param_expand_range
717
#' @template param_theme
818
#' @param ... (ignored).
919
#'
@@ -19,45 +29,50 @@
1929
#' task = tsk("iris")
2030
#' learner = lrn("classif.rpart", keep_model = TRUE)
2131
#' learner$train(task)
22-
#' autoplot(learner)
32+
#' autoplot(learner, type = "ggparty")
2333
#'
2434
#' # regression
2535
#' task = tsk("mtcars")
2636
#' learner = lrn("regr.rpart", keep_model = TRUE)
2737
#' learner$train(task)
28-
#' autoplot(learner)
38+
#' autoplot(learner, type = "ggparty")
2939
#' }
30-
autoplot.LearnerClassifRpart = function(object, theme = theme_minimal(), ...) { # nolint
40+
autoplot.LearnerClassifRpart = function(object, type = "prediction", task = NULL, grid_points = 100L, expand_range = 0, theme = theme_minimal(), ...) { # nolint
3141
assert_has_model(object)
3242

33-
if (is.null(object$model$model)) {
34-
stopf("Learner '%s' must be trained with `keep_model` set to `TRUE`", object$id)
35-
}
43+
switch(type,
44+
"prediction" = {
45+
NextMethod()
46+
},
47+
48+
"ggparty" = {
49+
require_namespaces(c("partykit", "ggparty"))
50+
target = all.vars(object$model$terms)[1L]
3651

37-
require_namespaces(c("partykit", "ggparty"))
38-
target = all.vars(object$model$terms)[1L]
52+
ggparty::ggparty(partykit::as.party(object$model)) +
53+
ggparty::geom_edge() +
54+
ggparty::geom_edge_label() +
55+
ggparty::geom_node_splitvar() +
56+
ggparty::geom_node_plot(
57+
gglist = list(
58+
geom_bar(aes(x = "", fill = .data[[target]]),
59+
alpha = 0.8,
60+
color = "#000000",
61+
linewidth = 0.5,
62+
position = position_fill()),
63+
xlab(target),
64+
scale_fill_viridis_d(end = 0.8),
65+
theme),
66+
ids = "terminal",
67+
shared_axis_labels= TRUE) +
68+
ggparty::geom_node_label(
69+
mapping = aes(label = paste0("n=", .data[["nodesize"]])),
70+
nudge_y = 0.03,
71+
ids = "terminal")
72+
},
3973

40-
ggparty::ggparty(partykit::as.party(object$model)) +
41-
ggparty::geom_edge() +
42-
ggparty::geom_edge_label() +
43-
ggparty::geom_node_splitvar() +
44-
ggparty::geom_node_plot(
45-
gglist = list(
46-
geom_bar(aes(x = "", fill = .data[[target]]),
47-
alpha = 0.8,
48-
color = "#000000",
49-
linewidth = 0.5,
50-
position = position_fill()),
51-
xlab(target),
52-
scale_fill_viridis_d(end = 0.8),
53-
theme),
54-
ids = "terminal",
55-
shared_axis_labels= TRUE) +
56-
ggparty::geom_node_label(
57-
mapping = aes(label = paste0("n=", .data[["nodesize"]])),
58-
nudge_y = 0.03,
59-
ids = "terminal"
60-
)
74+
stopf("Unknown plot type '%s'", type)
75+
)
6176
}
6277

6378
#' @export

0 commit comments

Comments
 (0)