Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ Suggests:
curl,
dbarts,
earth,
elasticnet,
evtree,
ExhaustiveSearch,
fastai (>= 2.2.2),
Expand All @@ -91,6 +92,7 @@ Suggests:
lgr,
LiblineaR,
lightgbm (>= 4.5.0),
MASS,
lme4 (>= 1.1-38),
mboost (>= 2.9-10),
mda,
Expand Down Expand Up @@ -124,6 +126,7 @@ Suggests:
RWeka,
sandwich,
sda,
sparseLDA,
sparsediscrim,
stats,
stepPlr,
Expand Down
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ export(LearnerClassifSMO)
export(LearnerClassifSda)
export(LearnerClassifSdlda)
export(LearnerClassifSimpleLogistic)
export(LearnerClassifSparseLDA)
export(LearnerClassifStepPlr)
export(LearnerClassifTabPFN)
export(LearnerClassifVotedPerceptron)
Expand Down
10 changes: 10 additions & 0 deletions R/bibentries.R
Original file line number Diff line number Diff line change
Expand Up @@ -693,6 +693,16 @@ bibentries = c( # nolint start
publisher = "ASA Websites",
doi = "10.1198/016214502753479248"
),
clemmensen2011sda = bibentry("article",
title = "Sparse discriminant analysis",
author = "Clemmensen, Line and Hastie, Trevor and Witten, Daniela and Ersboll, Bjarne",
year = "2011",
journal = "Journal of the American Statistical Association",
volume = "106",
number = "496",
pages = "1519--1531",
doi = "10.1198/jasa.2011.tm09728"
),
Srivastava2007mdeb = bibentry("article",
title = "Comparison of Discrimination Methods for High Dimensional Data",
author = "Srivastava, M. and Kubokawa, T.",
Expand Down
84 changes: 84 additions & 0 deletions R/learner_sparseLDA_classif_sparseLDA.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
#' @title Sparse Discriminant Analysis
#' @author awinterstetter
#' @name mlr_learners_classif.sparseLDA
#'
#' @description
#' Sparse Linear Discriminant Analysis for classification.
#' Calls [sparseLDA::sda()] from \CRANpkg{sparseLDA}.
#'
#' @section Custom mlr3 parameters:
#' - `Q` is set internally to `min(n_features, n_classes - 1)` when not supplied.
#' - `stop` is not exposed because it depends on the task.
#'
#' @templateVar id classif.sparseLDA
#' @template learner
#'
#' @references
#' `r format_bib("clemmensen2011sda")`
#'
#' @template seealso_learner
#' @template example
#' @export
LearnerClassifSparseLDA = R6Class("LearnerClassifSparseLDA",
inherit = LearnerClassif,
public = list(
#' @description
#' Creates a new instance of this [R6][R6::R6Class] class.
initialize = function() {
param_set = ps(
lambda = p_dbl(default = 1e-6, lower = 0, tags = "train"),
maxIte = p_int(default = 100L, lower = 0L, tags = "train"),
tol = p_dbl(default = 1e-6, lower = 0, tags = "train"),
trace = p_lgl(default = FALSE, tags = "train")
)

super$initialize(
id = "classif.sparseLDA",
packages = c("sparseLDA", "MASS", "elasticnet"),
feature_types = c("integer", "numeric"),
predict_types = c("response", "prob"),
param_set = param_set,
properties = c("multiclass", "twoclass"),
man = "mlr3extralearners::mlr_learners_classif.sparseLDA",
label = "Sparse Discriminant Analysis"
)
}
),
private = list(
.train = function(task) {
pars = self$param_set$get_values(tags = "train")

target = task$truth()
lvls = levels(target)
y = sapply(lvls, function(lvl) as.integer(as.character(target) == lvl))
colnames(y) = lvls
x = as.matrix(task$data(cols = task$feature_names))

if (is.null(pars$Q)) {
max_q = min(ncol(x), length(lvls) - 1L)
if (max_q >= 1L) {
pars$Q = max_q
}
}

invoke(sparseLDA::sda,
x = x,
y = y,
.args = pars
)
},
.predict = function(task) {
newdata = ordered_features(task, self)

pred = predict(object = self$model, newdata = newdata)

if (self$predict_type == "response") {
list(response = pred$class)
} else {
list(prob = pred$posterior)
}
}
)
)

.extralrns_dict$add("classif.sparseLDA", LearnerClassifSparseLDA)
147 changes: 147 additions & 0 deletions man/mlr_learners_classif.sparseLDA.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

30 changes: 30 additions & 0 deletions tests/testthat/test_paramtest_sparseLDA_classif_sparseLDA.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
skip_if_not_installed("sparseLDA")
skip_if_not_installed("elasticnet")
skip_if_not_installed("MASS")

test_that("classif.sparseLDA train", {
learner = lrn("classif.sparseLDA")
fun = sparseLDA:::sda.default # nolint
exclude = c(
"x", # handled internally
"y", # handled internally
"Q", # not exposed
"stop" # not exposed
)

paramtest = run_paramtest(learner, fun, exclude, tag = "train")
expect_paramtest(paramtest)
})

test_that("classif.sparseLDA predict", {
learner = lrn("classif.sparseLDA")
fun = sparseLDA:::predict.sda # nolint
exclude = c(
"object", # handled internally
"newdata", # handled internally
"type" # handled internally
)

paramtest = run_paramtest(learner, fun, exclude, tag = "predict")
expect_paramtest(paramtest)
})
1 change: 0 additions & 1 deletion tests/testthat/test_sda_classif_sda.R
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ skip_if_not_installed("sda")
test_that("autotest", {
learner = lrn("classif.sda")
expect_learner(learner)
# note that you can skip tests using the exclude argument
capture.output({
result = run_autotest(learner)
})
Expand Down
10 changes: 10 additions & 0 deletions tests/testthat/test_sparseLDA_classif_sparseLDA.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
skip_if_not_installed("sparseLDA")
skip_if_not_installed("elasticnet")
skip_if_not_installed("MASS")

test_that("autotest", {
learner = lrn("classif.sparseLDA")
expect_learner(learner)
result = run_autotest(learner, exclude = "feat_single_integer_binary")
expect_true(result, info = result$error)
})
Loading