Skip to content

Commit 776f524

Browse files
committed
Merge branch 'main' of github.com:mlr-org/mlr3viz
2 parents 98d1651 + 572d3de commit 776f524

11 files changed

+166
-8
lines changed

DESCRIPTION

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ Suggests:
4949
mlr3cluster,
5050
mlr3filters,
5151
mlr3fselect (>= 1.0.0),
52+
mlr3inference,
5253
mlr3learners,
5354
mlr3tuning (>= 1.0.0),
5455
paradox,
@@ -60,13 +61,20 @@ Suggests:
6061
stats,
6162
testthat (>= 3.0.0),
6263
vdiffr (>= 1.0.2),
63-
xgboost
64+
xgboost,
65+
survminer,
66+
mlr3proba (>= 0.6.3)
67+
Remotes:
68+
mlr-org/mlr3proba,
69+
mlr-org/mlr3inference
70+
Additional_repositories:
71+
https://mlr-org.r-universe.dev
6472
Config/testthat/edition: 3
6573
Config/testthat/parallel: true
6674
Encoding: UTF-8
6775
NeedsCompilation: no
6876
Roxygen: list(markdown = TRUE)
69-
RoxygenNote: 7.3.1
77+
RoxygenNote: 7.3.2
7078
Collate:
7179
'BenchmarkResult.R'
7280
'Filter.R'
@@ -79,6 +87,7 @@ Collate:
7987
'LearnerRegrCVGlmnet.R'
8088
'LearnerRegrGlmnet.R'
8189
'LearnerRegrRpart.R'
90+
'LearnerSurvCoxPH.R'
8291
'OptimInstanceBatchSingleCrit.R'
8392
'Prediction.R'
8493
'PredictionClassif.R'

NAMESPACE

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ S3method(autoplot,LearnerRegr)
1818
S3method(autoplot,LearnerRegrCVGlmnet)
1919
S3method(autoplot,LearnerRegrGlmnet)
2020
S3method(autoplot,LearnerRegrRpart)
21+
S3method(autoplot,LearnerSurvCoxPH)
2122
S3method(autoplot,OptimInstanceBatchSingleCrit)
2223
S3method(autoplot,PredictionClassif)
2324
S3method(autoplot,PredictionClust)
@@ -41,6 +42,7 @@ S3method(plot,LearnerClassifRpart)
4142
S3method(plot,LearnerRegrCVGlmnet)
4243
S3method(plot,LearnerRegrGlmnet)
4344
S3method(plot,LearnerRegrRpart)
45+
S3method(plot,LearnerSurvCoxPH)
4446
S3method(plot,PredictionClassif)
4547
S3method(plot,PredictionRegr)
4648
S3method(plot,ResampleResult)

NEWS.md

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
# mlr3viz (development version)
22

3+
- Add plot for `LearnerSurvCoxPH`.
4+
- Add plot for confidence intervals (`mlr3inference`)
5+
36
# mlr3viz 0.9.0
47

58
- Work with new bbotk 0.9.0 and mlr3tuning 0.21.0
6-
79
- Add plots for `EnsembleFSResult` object.
810

911
# mlr3viz 0.8.0

R/BenchmarkResult.R

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
#' Requires package \CRANpkg{precrec}.
1313
#' * `"prc"`: Precision recall curve.
1414
#' See `"roc"`.
15+
#' * `"ci"`: Plot confidence intervals. Pass a `msr("ci", ...)` from the `mlr3inference` package as argument `measure`.
1516
#'
1617
#' @param object ([mlr3::BenchmarkResult]).
1718
#' @template param_type
@@ -45,6 +46,43 @@ autoplot.BenchmarkResult = function(object, type = "boxplot", measure = NULL, th
4546

4647
task = object$tasks$task[[1L]]
4748
measure = mlr3::assert_measure(mlr3::as_measure(measure, task_type = task$task_type), task = task)
49+
50+
if (identical(type, "ci")) {
51+
mlr3misc::require_namespaces("mlr3inference")
52+
53+
assert_class(measure, "MeasureAbstractCi")
54+
mid = measure$id
55+
56+
tbl = object$aggregate(measure)
57+
58+
tmp = map(object$resamplings$resampling, function(x) {
59+
list(class(x), x$param_set$values)
60+
})
61+
62+
if (length(unique(tmp)) != 1) {
63+
stopf("Plot of type 'ci' requires exactly one resampling method")
64+
}
65+
66+
# static checker
67+
.data = NULL
68+
task_id = NULL
69+
p = ggplot(tbl, aes(x = .data[["learner_id"]], y = .data[[mid]])) +
70+
geom_point() +
71+
geom_errorbar(aes(ymin = .data[[paste0(mid, ".lower")]], ymax = .data[[paste0(mid, ".upper")]]), width = 0.2) +
72+
facet_wrap(vars(task_id), scales = "free_y") +
73+
labs(
74+
title = sprintf("Confidence Intervals for alpha = %s", measure$param_set$values$alpha),
75+
x = "Learner",
76+
y = paste0(measure$measure$id)
77+
) +
78+
theme +
79+
theme(
80+
axis.text.x = element_text(angle = 45, hjust = 1),
81+
axis.title.x = element_blank()
82+
)
83+
return(p)
84+
}
85+
4886
measure_id = measure$id
4987
tab = fortify(object, measure = measure)
5088
tab$nr = sprintf("%09d", tab$nr)

R/LearnerSurvCoxPH.R

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
#' @title Plots for Cox Proportional Hazards Learner
2+
#'
3+
#' @description
4+
#' Visualizations for [mlr3proba::LearnerSurvCoxPH].
5+
#'
6+
#' The argument `type` controls what kind of plot is drawn.
7+
#' The only possible choice right now is `"ggforest"` (default) which is a
8+
#' Forest Plot, using [ggforest][survminer::ggforest()].
9+
#' This plot displays the estimated hazard ratios (HRs) and their confidence
10+
#' intervals (CIs) for different variables included in the (trained) model.
11+
#'
12+
#' @param object ([mlr3proba::LearnerSurvCoxPH]).
13+
#'
14+
#' @template param_type
15+
#' @param ... Additional parameters passed down to `ggforest`.
16+
#'
17+
#' @return [ggplot2::ggplot()].
18+
#'
19+
#' @export
20+
#' @examples
21+
#' \donttest{
22+
#' if (requireNamespace("mlr3proba")) {
23+
#' library(mlr3proba)
24+
#' library(mlr3viz)
25+
#'
26+
#' task = tsk("lung")
27+
#' learner = lrn("surv.coxph")
28+
#' learner$train(task)
29+
#' autoplot(learner)
30+
#' }
31+
#' }
32+
autoplot.LearnerSurvCoxPH = function(object, type = "ggforest", ...) {
33+
assert_class(object, classes = "LearnerSurvCoxPH", null.ok = FALSE)
34+
assert_has_model(object)
35+
36+
switch(type,
37+
"ggforest" = {
38+
require_namespaces("survminer")
39+
suppressWarnings(survminer::ggforest(object$model, ...))
40+
},
41+
42+
stopf("Unknown plot type '%s'", type)
43+
)
44+
}
45+
46+
#' @export
47+
plot.LearnerSurvCoxPH = function(x, ...) {
48+
print(autoplot(x, ...))
49+
}

man/autoplot.BenchmarkResult.Rd

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/autoplot.EnsembleFSResult.Rd

Lines changed: 2 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/autoplot.LearnerClustHierarchical.Rd

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/autoplot.LearnerSurvCoxPH.Rd

Lines changed: 41 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/mlr3viz-package.Rd

Lines changed: 2 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

tests/testthat/test_BenchmarkResult.R

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,3 +48,19 @@ test_that("holdout roc plot (#54)", {
4848

4949
expect_doppelganger("bmr_holdout_roc", p)
5050
})
51+
52+
skip_if_not_installed("mlr3inference")
53+
skip_if_not_installed("rpart")
54+
55+
test_that("CI plot", {
56+
bmr = benchmark(benchmark_grid(tsks(c("mtcars", "boston_housing")),
57+
lrns(c("regr.featureless", "regr.rpart")), rsmp("holdout")))
58+
59+
p = autoplot(bmr, "ci", msr("ci", "regr.mse"))
60+
expect_true(is.ggplot(p))
61+
expect_doppelganger("bmr_holdout_ci", p)
62+
63+
bmr = benchmark(benchmark_grid(tsk("iris"), lrn("classif.rpart"),
64+
rsmps(c("holdout", "cv"))))
65+
expect_error(autoplot(bmr, "ci", msr("ci", "classif.acc")), "one resampling method")
66+
})

0 commit comments

Comments
 (0)