diff --git a/NAMESPACE b/NAMESPACE
index 53dbbfee4..53c328ef9 100644
--- a/NAMESPACE
+++ b/NAMESPACE
@@ -41,6 +41,7 @@ export(LearnerClassifRandomForest)
 export(LearnerClassifRandomForestSRC)
 export(LearnerClassifRandomForestWeka)
 export(LearnerClassifRandomPlantedForest)
+export(LearnerClassifTabNet)
 export(LearnerClassifRandomTree)
 export(LearnerClassifSGD)
 export(LearnerClassifSMO)
@@ -91,6 +92,7 @@ export(LearnerRegrRandomForest)
 export(LearnerRegrRandomForestSRC)
 export(LearnerRegrRandomForestWeka)
 export(LearnerRegrRandomPlantedForest)
+export(LearnerRegrTabNet)
 export(LearnerRegrRandomTree)
 export(LearnerRegrSGD)
 export(LearnerRegrSMOreg)
diff --git a/R/bibentries.R b/R/bibentries.R
index 77360bf3e..44305d780 100644
--- a/R/bibentries.R
+++ b/R/bibentries.R
@@ -586,6 +586,14 @@ bibentries = c( # nolint start
     month = "01",
     journal = "University of California, Berkeley"
   ),
+  arik2021tabnet = bibentry("inproceedings",
+    title = "TabNet: Attentive Interpretable Tabular Learning",
+    author = "Ar\xc4\xb1k, Sercan O and Pfister, Tomas",
+    booktitle = "AAAI",
+    volume = "35",
+    pages = "6679--6687",
+    year = "2021"
+  ),
   barnwal2022 = bibentry("article",
     title = "Survival Regression with Accelerated Failure Time Model in XGBoost",
     author = "Barnwal Avinash, Cho Hyunsu and Hocking Toby",
diff --git a/R/learner_tabnet_classif_tabnet.R b/R/learner_tabnet_classif_tabnet.R
new file mode 100644
index 000000000..0efbb6afb
--- /dev/null
+++ b/R/learner_tabnet_classif_tabnet.R
@@ -0,0 +1,102 @@
+#' @title Classification TabNet Learner
+#' @author Lukas Burk
+#' @name mlr_learners_classif.tabnet
+#'
+#' @template class_learner
+#' @templateVar id classif.tabnet
+#' @templateVar caller tabnet
+#' @references
+#' `r format_bib("arik2021tabnet")`
+#'
+#' @template seealso_learner
+#' @export
+#' @examples
+#' \dontrun{
+#' library(mlr3)
+#' library(mlr3extralearners)
+#' task = tsk("german_credit")
+#' lrn = lrn("classif.tabnet")
+#'
+#' lrn$param_set$values$epochs = 10
+#' lrn$param_set$values$attention_width = 8
+#' lrn$train(task)
+#' }
+LearnerClassifTabNet = R6Class("LearnerClassifTabNet",
+  inherit = LearnerClassif,
+  public = list(
+    #' @description
+    #' Creates a new instance of this [R6][R6::R6Class] class.
+    initialize = function() {
+      ps = params_tabnet()
+      super$initialize(
+        id = "classif.tabnet",
+        packages = "tabnet",
+        feature_types = c("logical", "integer", "numeric", "factor", "ordered"),
+        predict_types = c("response", "prob"),
+        param_set = ps,
+        properties = c("importance", "multiclass", "twoclass"),
+        man = "mlr3extralearners::mlr_learners_classif.tabnet",
+        label = "Attentive Interpretable Tabular Network"
+      )
+    },
+
+    #' @description
+    #' The importance scores are extracted from the slot `.$model$fit$importances`.
+    #' @return Named `numeric()`.
+    importance = function() {
+      if (is.null(self$model)) {
+        stopf("No model stored")
+      }
+      imp = self$model$fit$importances
+      sort(stats::setNames(imp$importance, imp$variables), decreasing = TRUE)
+    }
+  ),
+  private = list(
+    .train = function(task) {
+      # get parameters for training
+      pars = self$param_set$get_values(tags = "train")
+      pars_threads = pars$num_threads
+      pars$num_threads = NULL
+
+      # set the number of threads used by torch
+      torch::torch_set_num_threads(pars_threads)
+
+      # store column names to ensure consistent ordering in fit and predict
+      self$state$feature_names = task$feature_names
+
+      # use mlr3misc::invoke() (similar to do.call()) to call tabnet_fit()
+      mlr3misc::invoke(tabnet::tabnet_fit,
+        x = task$data(cols = task$feature_names),
+        y = task$data(cols = task$target_names),
+        .args = pars
+      )
+    },
+
+    .predict = function(task) {
+      # get parameters with tag "predict"
+      pars = self$param_set$get_values(tags = "predict")
+
+      # get newdata in the same column order as used during training
+      newdata = task$data(cols = self$state$feature_names)
+
+      if (self$predict_type == "response") {
+        pred = mlr3misc::invoke(predict, self$model, new_data = newdata,
+          type = "class", .args = pars)
+
+        list(response = pred[[".pred_class"]])
+      } else {
+        pred = mlr3misc::invoke(predict, self$model, new_data = newdata,
+          type = "prob", .args = pars)
+
+        # result is a data.frame with one column per class level, named
+        # '.pred_<level>'; strip the '.pred_' prefix to recover the labels
+        names(pred) = sub(pattern = ".pred_", replacement = "", names(pred), fixed = TRUE)
+
+        list(prob = as.matrix(pred))
+      }
+    }
+  )
+)
+
+.extralrns_dict$add("classif.tabnet", LearnerClassifTabNet)
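Reviewer note (outside the patch): `importance()` above only reorders the aggregate per-feature importances that `tabnet_fit()` stores in `$fit$importances`. A minimal usage sketch, assuming `tabnet` and its `torch` backend are installed:

library(mlr3)
library(mlr3extralearners)

task = tsk("german_credit")
learner = lrn("classif.tabnet", epochs = 5)
learner$train(task)

# named numeric vector, sorted decreasingly; names are feature names
head(learner$importance())

# probabilities come back as a matrix with one column per class level
learner$predict_type = "prob"
learner$predict(task)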
diff --git a/R/learner_tabnet_regr_tabnet.R b/R/learner_tabnet_regr_tabnet.R
new file mode 100644
index 000000000..08d8491e7
--- /dev/null
+++ b/R/learner_tabnet_regr_tabnet.R
@@ -0,0 +1,97 @@
+#' @title Regression TabNet Learner
+#' @author Lukas Burk
+#' @name mlr_learners_regr.tabnet
+#'
+#' @template class_learner
+#' @templateVar id regr.tabnet
+#' @templateVar caller tabnet
+#' @references
+#' `r format_bib("arik2021tabnet")`
+#'
+#' @template seealso_learner
+#' @export
+#' @examples
+#' \dontrun{
+#' library(mlr3)
+#' library(mlr3extralearners)
+#'
+#' task = tsk("boston_housing")
+#'
+#' # Create a learner & train it on the example task
+#' lrn = lrn("regr.tabnet")
+#'
+#' lrn$param_set$values$epochs = 10
+#' lrn$train(task)
+#'
+#' # Predict on training data, get RMSE
+#' predictions = lrn$predict(task)
+#' predictions$score(msr("regr.rmse"))
+#' }
+LearnerRegrTabNet = R6::R6Class("LearnerRegrTabNet",
+  inherit = LearnerRegr,
+  public = list(
+    #' @description
+    #' Creates a new instance of this [R6][R6::R6Class] class.
+    initialize = function() {
+      ps = params_tabnet()
+
+      super$initialize(
+        id = "regr.tabnet",
+        packages = "tabnet",
+        feature_types = c("logical", "integer", "numeric", "factor", "ordered"),
+        param_set = ps,
+        properties = c("importance"),
+        man = "mlr3extralearners::mlr_learners_regr.tabnet",
+        label = "Attentive Interpretable Tabular Network"
+      )
+    },
+
+    #' @description
+    #' The importance scores are extracted from the slot `.$model$fit$importances`.
+    #' @return Named `numeric()`.
+    importance = function() {
+      if (is.null(self$model)) {
+        stopf("No model stored")
+      }
+      imp = self$model$fit$importances
+      sort(stats::setNames(imp$importance, imp$variables), decreasing = TRUE)
+    }
+  ),
+  private = list(
+    .train = function(task) {
+      # get parameters for training
+      pars = self$param_set$get_values(tags = "train")
+      pars_threads = pars$num_threads
+      pars$num_threads = NULL
+
+      # set the number of threads used by torch
+      torch::torch_set_num_threads(pars_threads)
+
+      # store column names to ensure consistent ordering in fit and predict
+      self$state$feature_names = task$feature_names
+
+      # use mlr3misc::invoke() (similar to do.call()) to call tabnet_fit()
+      mlr3misc::invoke(tabnet::tabnet_fit,
+        x = task$data(cols = task$feature_names),
+        y = task$data(cols = task$target_names),
+        .args = pars
+      )
+    },
+    .predict = function(task) {
+      # get parameters with tag "predict"
+      pars = self$param_set$get_values(tags = "predict")
+
+      # get newdata in the same column order as used during training
+      newdata = task$data(cols = self$state$feature_names)
+
+      pred = mlr3misc::invoke(predict, self$model,
+        new_data = newdata,
+        .args = pars
+      )
+
+      # predict() on a tabnet fit returns a one-column data frame named ".pred"
+      list(response = pred[[".pred"]])
+    }
+  )
+)
+
+.extralrns_dict$add("regr.tabnet", LearnerRegrTabNet)
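For context on the `.pred` unwrapping in `.predict()` above, a sketch (not part of the patch; `tabnet`'s predict method follows the hardhat-style `.pred` naming convention):

library(mlr3)
library(mlr3extralearners)

task = tsk("mtcars")
learner = lrn("regr.tabnet", epochs = 5)
learner$train(task)

# the one-column ".pred" data frame is unwrapped into the numeric
# response vector that mlr3 expects
prediction = learner$predict(task)
prediction$score(msr("regr.rmse"))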
diff --git a/R/paramset_tabnet.R b/R/paramset_tabnet.R
new file mode 100644
index 000000000..ebcf71543
--- /dev/null
+++ b/R/paramset_tabnet.R
@@ -0,0 +1,48 @@
+params_tabnet = function() {
+  param_set = ParamSet$new(list(
+    num_threads = p_int(default = 1L, lower = 1L, upper = Inf, tags = c("train", "threads")),
+    batch_size = p_int(default = 256L, lower = 1L, upper = Inf, tags = "train"),
+    penalty = p_dbl(default = 0.001, tags = "train"),
+
+    # FIXME: NULL here is used for bool FALSE, not sure what to do there.
+    clip_value = p_uty(default = NULL, tags = "train"),
+    loss = p_fct(default = "auto", levels = c("auto", "mse", "cross_entropy"), tags = "train"),
+    epochs = p_int(default = 5L, lower = 1L, upper = Inf, tags = "train"),
+    drop_last = p_lgl(default = FALSE, tags = "train"),
+    decision_width = p_int(default = 8L, lower = 1L, upper = Inf, tags = "train"),
+    attention_width = p_int(default = 8L, lower = 1L, upper = Inf, tags = "train"),
+    num_steps = p_int(default = 3L, lower = 1L, upper = Inf, tags = "train"),
+    feature_reusage = p_dbl(default = 1.3, lower = 0, upper = Inf, tags = "train"),
+    mask_type = p_fct(default = "sparsemax", levels = c("sparsemax", "entmax"), tags = "train"),
+    virtual_batch_size = p_int(default = 128L, lower = 1L, upper = Inf, tags = "train"),
+    valid_split = p_dbl(default = 0, lower = 0, upper = 1, tags = "train"),
+    learn_rate = p_dbl(default = 0.02, lower = 0, upper = 1, tags = "train"),
+
+    # FIXME: Currently either 'adam' or an arbitrary optimizer function according to the docs
+    optimizer = p_uty(default = "adam", tags = "train"),
+
+    # FIXME: This is either NULL, a function, or the string "steps"; needs a custom_check fun
+    lr_scheduler = p_uty(default = NULL, tags = "train"),
+
+    lr_decay = p_dbl(default = 0.1, lower = 0, upper = 1, tags = "train"),
+    step_size = p_int(default = 30L, lower = 1L, upper = Inf, tags = "train"),
+    checkpoint_epochs = p_int(default = 10L, lower = 1L, upper = Inf, tags = "train"),
+    cat_emb_dim = p_int(default = 1L, lower = 0L, upper = Inf, tags = "train"),
+    num_independent = p_int(default = 2L, lower = 0, upper = Inf, tags = "train"),
+    num_shared = p_int(default = 2L, lower = 0, upper = Inf, tags = "train"),
+    momentum = p_dbl(default = 0.02, lower = 0, upper = 1, tags = "train"),
+    pretraining_ratio = p_dbl(default = 0.5, lower = 0, upper = 1, tags = "train"),
+    verbose = p_lgl(default = FALSE, tags = "train"),
+    device = p_fct(default = "auto", levels = c("auto", "cpu", "cuda"), tags = "train"),
+    importance_sample_size = p_int(lower = 0, upper = 1e5, special_vals = list(NULL), tags = "train")
+  ))
+
+  # Set param values that differ from the defaults in tabnet_fit()
+  param_set$values = list(
+    num_threads = 1L,
+    # clip_value = NULL,
+    decision_width = 8L,
+    attention_width = 8L
+  )
+  return(param_set)
+}
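A sketch (outside the patch) of how the shared ParamSet is consumed; values set through the sugar function override the `tabnet_fit()` defaults listed above:

library(mlr3)
library(mlr3extralearners)

learner = lrn("classif.tabnet",
  epochs = 10,       # tagged "train", forwarded to tabnet::tabnet_fit()
  batch_size = 512,
  num_threads = 2    # stripped from pars and routed to torch_set_num_threads()
)

# optimizer is untyped (p_uty): per the FIXME above it accepts "adam" or an
# optimizer function, e.g. (assumption based on the tabnet docs referenced
# there) a torch optimizer constructor:
learner$param_set$values$optimizer = torch::optim_adam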
}}\preformatted{mlr_learners$get("classif.tabnet") +lrn("classif.tabnet") +}\if{html}{\out{
}} +} + +\section{Meta Information}{ + +\itemize{ +\item Task type: \dQuote{classif} +\item Predict Types: \dQuote{response}, \dQuote{prob} +\item Feature Types: \dQuote{logical}, \dQuote{integer}, \dQuote{numeric}, \dQuote{factor}, \dQuote{ordered} +\item Required Packages: \CRANpkg{mlr3}, \CRANpkg{tabnet} +} +} + +\section{Parameters}{ +\tabular{lllll}{ + Id \tab Type \tab Default \tab Levels \tab Range \cr + num_threads \tab integer \tab 1 \tab \tab \eqn{[1, \infty)}{[1, Inf)} \cr + batch_size \tab integer \tab 256 \tab \tab \eqn{[1, \infty)}{[1, Inf)} \cr + penalty \tab numeric \tab 0.001 \tab \tab \eqn{(-\infty, \infty)}{(-Inf, Inf)} \cr + clip_value \tab untyped \tab \tab \tab - \cr + loss \tab character \tab auto \tab auto, mse, cross_entropy \tab - \cr + epochs \tab integer \tab 5 \tab \tab \eqn{[1, \infty)}{[1, Inf)} \cr + drop_last \tab logical \tab FALSE \tab TRUE, FALSE \tab - \cr + decision_width \tab integer \tab 8 \tab \tab \eqn{[1, \infty)}{[1, Inf)} \cr + attention_width \tab integer \tab 8 \tab \tab \eqn{[1, \infty)}{[1, Inf)} \cr + num_steps \tab integer \tab 3 \tab \tab \eqn{[1, \infty)}{[1, Inf)} \cr + feature_reusage \tab numeric \tab 1.3 \tab \tab \eqn{[0, \infty)}{[0, Inf)} \cr + mask_type \tab character \tab sparsemax \tab sparsemax, entmax \tab - \cr + virtual_batch_size \tab integer \tab 128 \tab \tab \eqn{[1, \infty)}{[1, Inf)} \cr + valid_split \tab numeric \tab 0 \tab \tab \eqn{[0, 1]}{[0, 1]} \cr + learn_rate \tab numeric \tab 0.02 \tab \tab \eqn{[0, 1]}{[0, 1]} \cr + optimizer \tab untyped \tab adam \tab \tab - \cr + lr_scheduler \tab untyped \tab \tab \tab - \cr + lr_decay \tab numeric \tab 0.1 \tab \tab \eqn{[0, 1]}{[0, 1]} \cr + step_size \tab integer \tab 30 \tab \tab \eqn{[1, \infty)}{[1, Inf)} \cr + checkpoint_epochs \tab integer \tab 10 \tab \tab \eqn{[1, \infty)}{[1, Inf)} \cr + cat_emb_dim \tab integer \tab 1 \tab \tab \eqn{[0, \infty)}{[0, Inf)} \cr + num_independent \tab integer \tab 2 \tab \tab \eqn{[0, \infty)}{[0, Inf)} \cr + num_shared \tab integer \tab 2 \tab \tab \eqn{[0, \infty)}{[0, Inf)} \cr + momentum \tab numeric \tab 0.02 \tab \tab \eqn{[0, 1]}{[0, 1]} \cr + pretraining_ratio \tab numeric \tab 0.5 \tab \tab \eqn{[0, 1]}{[0, 1]} \cr + verbose \tab logical \tab FALSE \tab TRUE, FALSE \tab - \cr + device \tab character \tab auto \tab auto, cpu, cuda \tab - \cr + importance_sample_size \tab integer \tab - \tab \tab \eqn{[0, 1e+05]}{[0, 1e+05]} \cr +} +} + +\examples{ +\dontrun{ +library(mlr3) +library(mlr3torch) +task = tsk("german_credit") +lrn = lrn("classif.tabnet") + +lrn$param_set$values$epochs = 10 +lrn$param_set$values$attention_width = 8 +lrn$train(task) +} +} +\references{ +Arık, O S, Pfister, Tomas (2021). +\dQuote{Tabnet: Attentive interpretable tabular learning.} +In \emph{AAAI}, volume 35, 6679--6687. +} +\seealso{ +\itemize{ +\item \link[mlr3misc:Dictionary]{Dictionary} of \link[mlr3:Learner]{Learners}: \link[mlr3:mlr_learners]{mlr3::mlr_learners}. +\item \code{as.data.table(mlr_learners)} for a table of available \link[=Learner]{Learners} in the running session (depending on the loaded packages). +\item Chapter in the \href{https://mlr3book.mlr-org.com/}{mlr3book}: \url{https://mlr3book.mlr-org.com/basics.html#learners} +\item \CRANpkg{mlr3learners} for a selection of recommended learners. +\item \CRANpkg{mlr3cluster} for unsupervised clustering learners. +\item \CRANpkg{mlr3pipelines} to combine learners with pre- and postprocessing steps. 
+\item \CRANpkg{mlr3tuning} for tuning of hyperparameters, \CRANpkg{mlr3tuningspaces} for established default tuning spaces. +} +} +\author{ +Lukas Burk +} +\section{Super classes}{ +\code{\link[mlr3:Learner]{mlr3::Learner}} -> \code{\link[mlr3:LearnerClassif]{mlr3::LearnerClassif}} -> \code{LearnerClassifTabNet} +} +\section{Methods}{ +\subsection{Public methods}{ +\itemize{ +\item \href{#method-LearnerClassifTabNet-new}{\code{LearnerClassifTabNet$new()}} +\item \href{#method-LearnerClassifTabNet-importance}{\code{LearnerClassifTabNet$importance()}} +\item \href{#method-LearnerClassifTabNet-clone}{\code{LearnerClassifTabNet$clone()}} +} +} +\if{html}{\out{ +
Inherited methods + +
+}} +\if{html}{\out{
}} +\if{html}{\out{}} +\if{latex}{\out{\hypertarget{method-LearnerClassifTabNet-new}{}}} +\subsection{Method \code{new()}}{ +Creates a new instance of this \link[R6:R6Class]{R6} class. +\subsection{Usage}{ +\if{html}{\out{
}}\preformatted{LearnerClassifTabNet$new()}\if{html}{\out{
}} +} + +} +\if{html}{\out{
}} +\if{html}{\out{}} +\if{latex}{\out{\hypertarget{method-LearnerClassifTabNet-importance}{}}} +\subsection{Method \code{importance()}}{ +The importance scores are extracted from the slot \code{.$model$fit$importances}. +\subsection{Usage}{ +\if{html}{\out{
}}\preformatted{LearnerClassifTabNet$importance()}\if{html}{\out{
}} +} + +\subsection{Returns}{ +Named \code{numeric()}. +} +} +\if{html}{\out{
}} +\if{html}{\out{}} +\if{latex}{\out{\hypertarget{method-LearnerClassifTabNet-clone}{}}} +\subsection{Method \code{clone()}}{ +The objects of this class are cloneable with this method. +\subsection{Usage}{ +\if{html}{\out{
}}\preformatted{LearnerClassifTabNet$clone(deep = FALSE)}\if{html}{\out{
}} +} + +\subsection{Arguments}{ +\if{html}{\out{
}} +\describe{ +\item{\code{deep}}{Whether to make a deep clone.} +} +\if{html}{\out{
}} +} +} +} diff --git a/man/mlr_learners_regr.tabnet.Rd b/man/mlr_learners_regr.tabnet.Rd new file mode 100644 index 000000000..c247c7eb6 --- /dev/null +++ b/man/mlr_learners_regr.tabnet.Rd @@ -0,0 +1,168 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/learner_tabnet_regr_tabnet.R +\name{mlr_learners_regr.tabnet} +\alias{mlr_learners_regr.tabnet} +\alias{LearnerRegrTabNet} +\title{Regression TabNet Learner} +\description{ +Regression TabNet Learner + +Regression TabNet Learner +} +\section{Dictionary}{ + This \link{Learner} can be instantiated via the +\link[mlr3misc:Dictionary]{dictionary} \link{mlr_learners} or with the associated +sugar function \code{\link[=lrn]{lrn()}}: + +\if{html}{\out{
}}\preformatted{mlr_learners$get("regr.tabnet") +lrn("regr.tabnet") +}\if{html}{\out{
}} +} + +\section{Meta Information}{ + +\itemize{ +\item Task type: \dQuote{regr} +\item Predict Types: \dQuote{response} +\item Feature Types: \dQuote{logical}, \dQuote{integer}, \dQuote{numeric}, \dQuote{factor}, \dQuote{ordered} +\item Required Packages: \CRANpkg{mlr3}, \CRANpkg{tabnet} +} +} + +\section{Parameters}{ +\tabular{lllll}{ + Id \tab Type \tab Default \tab Levels \tab Range \cr + num_threads \tab integer \tab 1 \tab \tab \eqn{[1, \infty)}{[1, Inf)} \cr + batch_size \tab integer \tab 256 \tab \tab \eqn{[1, \infty)}{[1, Inf)} \cr + penalty \tab numeric \tab 0.001 \tab \tab \eqn{(-\infty, \infty)}{(-Inf, Inf)} \cr + clip_value \tab untyped \tab \tab \tab - \cr + loss \tab character \tab auto \tab auto, mse, cross_entropy \tab - \cr + epochs \tab integer \tab 5 \tab \tab \eqn{[1, \infty)}{[1, Inf)} \cr + drop_last \tab logical \tab FALSE \tab TRUE, FALSE \tab - \cr + decision_width \tab integer \tab 8 \tab \tab \eqn{[1, \infty)}{[1, Inf)} \cr + attention_width \tab integer \tab 8 \tab \tab \eqn{[1, \infty)}{[1, Inf)} \cr + num_steps \tab integer \tab 3 \tab \tab \eqn{[1, \infty)}{[1, Inf)} \cr + feature_reusage \tab numeric \tab 1.3 \tab \tab \eqn{[0, \infty)}{[0, Inf)} \cr + mask_type \tab character \tab sparsemax \tab sparsemax, entmax \tab - \cr + virtual_batch_size \tab integer \tab 128 \tab \tab \eqn{[1, \infty)}{[1, Inf)} \cr + valid_split \tab numeric \tab 0 \tab \tab \eqn{[0, 1]}{[0, 1]} \cr + learn_rate \tab numeric \tab 0.02 \tab \tab \eqn{[0, 1]}{[0, 1]} \cr + optimizer \tab untyped \tab adam \tab \tab - \cr + lr_scheduler \tab untyped \tab \tab \tab - \cr + lr_decay \tab numeric \tab 0.1 \tab \tab \eqn{[0, 1]}{[0, 1]} \cr + step_size \tab integer \tab 30 \tab \tab \eqn{[1, \infty)}{[1, Inf)} \cr + checkpoint_epochs \tab integer \tab 10 \tab \tab \eqn{[1, \infty)}{[1, Inf)} \cr + cat_emb_dim \tab integer \tab 1 \tab \tab \eqn{[0, \infty)}{[0, Inf)} \cr + num_independent \tab integer \tab 2 \tab \tab \eqn{[0, \infty)}{[0, Inf)} \cr + num_shared \tab integer \tab 2 \tab \tab \eqn{[0, \infty)}{[0, Inf)} \cr + momentum \tab numeric \tab 0.02 \tab \tab \eqn{[0, 1]}{[0, 1]} \cr + pretraining_ratio \tab numeric \tab 0.5 \tab \tab \eqn{[0, 1]}{[0, 1]} \cr + verbose \tab logical \tab FALSE \tab TRUE, FALSE \tab - \cr + device \tab character \tab auto \tab auto, cpu, cuda \tab - \cr + importance_sample_size \tab integer \tab - \tab \tab \eqn{[0, 1e+05]}{[0, 1e+05]} \cr +} +} + +\examples{ +\dontrun{ +library(mlr3) +library(mlr3torch) + +task = tsk("boston_housing") + +# Creating a learner & training on example task +lrn = lrn("regr.tabnet") + +lrn$param_set$values$epochs = 10 +lrn$train(task) + +# Predict on training data, get RMSE +predictions = lrn$predict(task) +predictions$score(msr("regr.rmse")) +} +} +\references{ +Arık, O S, Pfister, Tomas (2021). +\dQuote{Tabnet: Attentive interpretable tabular learning.} +In \emph{AAAI}, volume 35, 6679--6687. +} +\seealso{ +\itemize{ +\item \link[mlr3misc:Dictionary]{Dictionary} of \link[mlr3:Learner]{Learners}: \link[mlr3:mlr_learners]{mlr3::mlr_learners}. +\item \code{as.data.table(mlr_learners)} for a table of available \link[=Learner]{Learners} in the running session (depending on the loaded packages). +\item Chapter in the \href{https://mlr3book.mlr-org.com/}{mlr3book}: \url{https://mlr3book.mlr-org.com/basics.html#learners} +\item \CRANpkg{mlr3learners} for a selection of recommended learners. +\item \CRANpkg{mlr3cluster} for unsupervised clustering learners. 
+\item \CRANpkg{mlr3pipelines} to combine learners with pre- and postprocessing steps. +\item \CRANpkg{mlr3tuning} for tuning of hyperparameters, \CRANpkg{mlr3tuningspaces} for established default tuning spaces. +} +} +\author{ +Lukas Burk +} +\section{Super classes}{ +\code{\link[mlr3:Learner]{mlr3::Learner}} -> \code{\link[mlr3:LearnerRegr]{mlr3::LearnerRegr}} -> \code{LearnerRegrTabnet} +} +\section{Methods}{ +\subsection{Public methods}{ +\itemize{ +\item \href{#method-LearnerRegrTabnet-new}{\code{LearnerRegrTabNet$new()}} +\item \href{#method-LearnerRegrTabnet-importance}{\code{LearnerRegrTabNet$importance()}} +\item \href{#method-LearnerRegrTabnet-clone}{\code{LearnerRegrTabNet$clone()}} +} +} +\if{html}{\out{ +
Inherited methods + +
+}} +\if{html}{\out{
}} +\if{html}{\out{}} +\if{latex}{\out{\hypertarget{method-LearnerRegrTabnet-new}{}}} +\subsection{Method \code{new()}}{ +Creates a new instance of this \link[R6:R6Class]{R6} class. +\subsection{Usage}{ +\if{html}{\out{
}}\preformatted{LearnerRegrTabNet$new()}\if{html}{\out{
}} +} + +} +\if{html}{\out{
}} +\if{html}{\out{}} +\if{latex}{\out{\hypertarget{method-LearnerRegrTabnet-importance}{}}} +\subsection{Method \code{importance()}}{ +The importance scores are extracted from the slot \code{.$model$fit$importances}. +\subsection{Usage}{ +\if{html}{\out{
}}\preformatted{LearnerRegrTabNet$importance()}\if{html}{\out{
}} +} + +\subsection{Returns}{ +Named \code{numeric()}. +} +} +\if{html}{\out{
}} +\if{html}{\out{}} +\if{latex}{\out{\hypertarget{method-LearnerRegrTabnet-clone}{}}} +\subsection{Method \code{clone()}}{ +The objects of this class are cloneable with this method. +\subsection{Usage}{ +\if{html}{\out{
}}\preformatted{LearnerRegrTabNet$clone(deep = FALSE)}\if{html}{\out{
}} +} + +\subsection{Arguments}{ +\if{html}{\out{
}} +\describe{ +\item{\code{deep}}{Whether to make a deep clone.} +} +\if{html}{\out{
+\if{html}{\out{</div>}}
+}
+}
+}
diff --git a/man/params_tabnet.Rd b/man/params_tabnet.Rd
new file mode 100644
index 000000000..88c35669f
--- /dev/null
+++ b/man/params_tabnet.Rd
@@ -0,0 +1,15 @@
+% Generated by roxygen2: do not edit by hand
+% Please edit documentation in R/paramset_tabnet.R
+\name{params_tabnet}
+\alias{params_tabnet}
+\title{Wrapper for the TabNet ParamSet}
+\usage{
+params_tabnet()
+}
+\value{
+Object of classes \verb{"ParamSet" "R6"}, suitable for use with the TabNet learners.
+}
+\description{
+Used to de-duplicate learner setup.
+}
+\keyword{internal}
diff --git a/tests/testthat/helper.R b/tests/testthat/helper.R
index 54d6785a2..ee21de7b1 100644
--- a/tests/testthat/helper.R
+++ b/tests/testthat/helper.R
@@ -3,8 +3,10 @@ library(mlr3)
 library(mlr3extralearners)
 library(mlr3proba)
 
-lapply(list.files(system.file("testthat", package = "mlr3"),
-  pattern = "^helper.*\\.[rR]", full.names = TRUE), source)
+tmp = list.files(system.file("testthat", package = "mlr3"),
+  pattern = "^helper.*\\.[rR]", full.names = TRUE)
+tmp = tmp[basename(tmp) != "helper_debugging.R"]
+lapply(tmp, source)
 
 lapply(list.files(system.file("testthat", package = "mlr3proba"),
   pattern = "^helper.*\\.[rR]", full.names = TRUE), source)
diff --git a/tests/testthat/test_paramset_tabnet_regr_tabnet.R b/tests/testthat/test_paramset_tabnet_regr_tabnet.R
new file mode 100644
index 000000000..c949b07c5
--- /dev/null
+++ b/tests/testthat/test_paramset_tabnet_regr_tabnet.R
@@ -0,0 +1,21 @@
+test_that("paramtest regr.tabnet train", {
+  learner = lrn("regr.tabnet")
+  fun = tabnet::tabnet_fit # nolint
+
+  exclude = c(
+  )
+
+  paramtest = run_paramtest(learner, fun, exclude, tag = "train")
+  expect_paramtest(paramtest)
+})
+
+test_that("paramtest regr.tabnet predict", {
+  learner = lrn("regr.tabnet")
+  fun = tabnet:::predict.tabnet_fit # nolint
+
+  exclude = c(
+  )
+
+  paramtest = run_paramtest(learner, fun, exclude, tag = "predict")
+  expect_paramtest(paramtest)
+})
diff --git a/tests/testthat/test_tabnet_classif_tabnet.R b/tests/testthat/test_tabnet_classif_tabnet.R
new file mode 100644
index 000000000..8381951c1
--- /dev/null
+++ b/tests/testthat/test_tabnet_classif_tabnet.R
@@ -0,0 +1,6 @@
+test_that("autotest", {
+  learner = lrn("classif.tabnet", device = "cpu", epochs = 20)
+  expect_learner(learner)
+  result = run_autotest(learner)
+  expect_true(result, info = result$error)
+})
diff --git a/tests/testthat/test_tabnet_regr_tabnet.R b/tests/testthat/test_tabnet_regr_tabnet.R
new file mode 100644
index 000000000..fa5f75c64
--- /dev/null
+++ b/tests/testthat/test_tabnet_regr_tabnet.R
@@ -0,0 +1,6 @@
+test_that("autotest", {
+  learner = lrn("regr.tabnet", device = "cpu", epochs = 30)
+  expect_learner(learner)
+  result = run_autotest(learner)
+  expect_true(result, info = result$error)
+})
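Reviewer note (outside the patch): `run_autotest()`, `run_paramtest()`, and the related expectations come from the test helpers sourced in `tests/testthat/helper.R`, so the new tests are easiest to run through the package's own test machinery, e.g. a sketch:

# from a local checkout of mlr3extralearners
devtools::load_all()
testthat::test_file("tests/testthat/test_tabnet_classif_tabnet.R")
testthat::test_file("tests/testthat/test_paramset_tabnet_regr_tabnet.R")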