diff --git a/DESCRIPTION b/DESCRIPTION index 300c7515..2139d2e2 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -50,6 +50,9 @@ Suggests: rpart, testthat (>= 3.0.0), xgboost +Remotes: + mlr-org/mlr3, + mlr-org/bbotk VignetteBuilder: knitr Config/testthat/edition: 3 diff --git a/R/ArchiveAsyncTuning.R b/R/ArchiveAsyncTuning.R index ad0e208d..2e272a27 100644 --- a/R/ArchiveAsyncTuning.R +++ b/R/ArchiveAsyncTuning.R @@ -170,8 +170,7 @@ ArchiveAsyncTuning = R6Class("ArchiveAsyncTuning", # cache benchmark result if (self$rush$n_finished_tasks > private$.benchmark_result$n_resample_results) { bmrs = map(self$finished_data$resample_result, as_benchmark_result) - init = BenchmarkResult$new() - private$.benchmark_result = Reduce(function(lhs, rhs) lhs$combine(rhs), bmrs, init = init) + private$.benchmark_result = Reduce(function(lhs, rhs) lhs$combine(rhs), bmrs) } private$.benchmark_result } diff --git a/R/CallbackAsyncTuning.R b/R/CallbackAsyncTuning.R index 616addad..5de7e246 100644 --- a/R/CallbackAsyncTuning.R +++ b/R/CallbackAsyncTuning.R @@ -1,4 +1,4 @@ -#' @title Create Asynchronous Tuning Callback +#' @title Asynchronous Tuning Callback #' #' @description #' Specialized [bbotk::CallbackAsync] for asynchronous tuning. @@ -17,6 +17,26 @@ CallbackAsyncTuning = R6Class("CallbackAsyncTuning", #' Called in `ObjectiveTuningAsync$eval()`. on_eval_after_xs = NULL, + #' @field on_resample_begin (`function()`)\cr + #' Stage called at the beginning of an evaluation. + #' Called in `workhorse()` (internal). + on_resample_begin = NULL, + + #' @field on_resample_before_train (`function()`)\cr + #' Stage called before training the learner. + #' Called in `workhorse()` (internal). + on_resample_before_train = NULL, + + #' @field on_resample_before_predict (`function()`)\cr + #' Stage called before predicting. + #' Called in `workhorse()` (internal). + on_resample_before_predict = NULL, + + #' @field on_resample_end (`function()`)\cr + #' Stage called at the end of an evaluation. + #' Called in `workhorse()` (internal). + on_resample_end = NULL, + #' @field on_eval_after_resample (`function()`)\cr #' Stage called after hyperparameter configurations are evaluated. #' Called in `ObjectiveTuningAsync$eval()`. @@ -52,6 +72,12 @@ CallbackAsyncTuning = R6Class("CallbackAsyncTuning", #' - on_optimizer_before_eval #' Start Evaluation #' - on_eval_after_xs +#' Start Resampling Iteration +#' - on_resample_begin +#' - on_resample_before_train +#' - on_resample_before_predict +#' - on_resample_end +#' End Resampling Iteration #' - on_eval_after_resample #' - on_eval_before_archive #' End Evaluation @@ -72,7 +98,7 @@ CallbackAsyncTuning = R6Class("CallbackAsyncTuning", #' @details #' When implementing a callback, each function must have two arguments named `callback` and `context`. #' A callback can write data to the state (`$state`), e.g. settings that affect the callback itself. -#' Tuning callbacks access [ContextAsyncTuning]. +#' Tuning callbacks access [ContextAsyncTuning] and [mlr3::ContextResample]. #' #' @param id (`character(1)`)\cr #' Identifier for the new instance. @@ -101,6 +127,26 @@ CallbackAsyncTuning = R6Class("CallbackAsyncTuning", #' Called in `ObjectiveTuningAsync$eval()`. #' The functions must have two arguments named `callback` and `context`. #' The argument of `$.eval(xs)` is available in the `context`. +#' @param on_resample_begin (`function()`)\cr +#' Stage called at the beginning of a resampling iteration. +#' Called in `workhorse()` (internal). +#' See also [mlr3::callback_resample()]. +#' The functions must have two arguments named `callback` and `context`. +#' @param on_resample_before_train (`function()`)\cr +#' Stage called before training the learner. +#' Called in `workhorse()` (internal). +#' See also [mlr3::callback_resample()]. +#' The functions must have two arguments named `callback` and `context`. +#' @param on_resample_before_predict (`function()`)\cr +#' Stage called before predicting. +#' Called in `workhorse()` (internal). +#' See also [mlr3::callback_resample()]. +#' The functions must have two arguments named `callback` and `context`. +#' @param on_resample_end (`function()`)\cr +#' Stage called at the end of a resampling iteration. +#' Called in `workhorse()` (internal). +#' See also [mlr3::callback_resample()]. +#' The functions must have two arguments named `callback` and `context`. #' @param on_eval_after_resample (`function()`)\cr #' Stage called after a hyperparameter configuration is evaluated. #' Called in `ObjectiveTuningAsync$eval()`. @@ -152,6 +198,10 @@ callback_async_tuning = function( on_worker_begin = NULL, on_optimizer_before_eval = NULL, on_eval_after_xs = NULL, + on_resample_begin = NULL, + on_resample_before_train = NULL, + on_resample_before_predict = NULL, + on_resample_end = NULL, on_eval_after_resample = NULL, on_eval_before_archive = NULL, on_optimizer_after_eval = NULL, @@ -167,6 +217,10 @@ callback_async_tuning = function( on_worker_begin, on_optimizer_before_eval, on_eval_after_xs, + on_resample_begin, + on_resample_before_train, + on_resample_before_predict, + on_resample_end, on_eval_after_resample, on_eval_before_archive, on_optimizer_after_eval, @@ -181,6 +235,10 @@ callback_async_tuning = function( "on_worker_begin", "on_optimizer_before_eval", "on_eval_after_xs", + "on_resample_begin", + "on_resample_before_train", + "on_resample_before_predict", + "on_resample_end", "on_eval_after_resample", "on_eval_before_archive", "on_optimizer_after_eval", diff --git a/R/CallbackBatchTuning.R b/R/CallbackBatchTuning.R index 2de95b4d..d1c2047b 100644 --- a/R/CallbackBatchTuning.R +++ b/R/CallbackBatchTuning.R @@ -24,6 +24,26 @@ CallbackBatchTuning= R6Class("CallbackBatchTuning", #' Called in `ObjectiveTuningBatch$eval_many()`. on_eval_after_design = NULL, + #' @field on_resample_begin (`function()`)\cr + #' Stage called at the beginning of an evaluation. + #' Called in `workhorse()` (internal). + on_resample_begin = NULL, + + #' @field on_resample_before_train (`function()`)\cr + #' Stage called before training the learner. + #' Called in `workhorse()` (internal). + on_resample_before_train = NULL, + + #' @field on_resample_before_predict (`function()`)\cr + #' Stage called before predicting. + #' Called in `workhorse()` (internal). + on_resample_before_predict = NULL, + + #' @field on_resample_end (`function()`)\cr + #' Stage called at the end of an evaluation. + #' Called in `workhorse()` (internal). + on_resample_end = NULL, + #' @field on_eval_after_benchmark (`function()`)\cr #' Stage called after hyperparameter configurations are evaluated. #' Called in `ObjectiveTuningBatch$eval_many()`. @@ -57,6 +77,12 @@ CallbackBatchTuning= R6Class("CallbackBatchTuning", #' - on_optimizer_before_eval #' Start Evaluation #' - on_eval_after_design +#' Start Resampling Iteration +#' - on_resample_begin +#' - on_resample_before_train +#' - on_resample_before_predict +#' - on_resample_end +#' End Resampling Iteration #' - on_eval_after_benchmark #' - on_eval_before_archive #' End Evaluation @@ -70,7 +96,7 @@ CallbackBatchTuning= R6Class("CallbackBatchTuning", #' ``` #' #' See also the section on parameters for more information on the stages. -#' A tuning callback works with [ContextBatchTuning]. +#' A tuning callback works with [ContextBatchTuning] and [mlr3::ContextResample]. #' #' @details #' When implementing a callback, each function must have two arguments named `callback` and `context`. @@ -100,6 +126,26 @@ CallbackBatchTuning= R6Class("CallbackBatchTuning", #' The functions must have two arguments named `callback` and `context`. #' The arguments of `$eval_many(xss, resampling)` are available in `context`. #' Additionally, the `design` is available in `context`. +#' @param on_resample_begin (`function()`)\cr +#' Stage called at the beginning of a resampling iteration. +#' Called in `workhorse()` (internal). +#' See also [mlr3::callback_resample()]. +#' The functions must have two arguments named `callback` and `context`. +#' @param on_resample_before_train (`function()`)\cr +#' Stage called before training the learner. +#' Called in `workhorse()` (internal). +#' See also [mlr3::callback_resample()]. +#' The functions must have two arguments named `callback` and `context`. +#' @param on_resample_before_predict (`function()`)\cr +#' Stage called before predicting. +#' Called in `workhorse()` (internal). +#' See also [mlr3::callback_resample()]. +#' The functions must have two arguments named `callback` and `context`. +#' @param on_resample_end (`function()`)\cr +#' Stage called at the end of a resampling iteration. +#' Called in `workhorse()` (internal). +#' See also [mlr3::callback_resample()]. +#' The functions must have two arguments named `callback` and `context`. #' @param on_eval_after_benchmark (`function()`)\cr #' Stage called after hyperparameter configurations are evaluated. #' Called in `ObjectiveTuningBatch$eval_many()`. @@ -150,6 +196,10 @@ callback_batch_tuning = function( on_optimization_begin = NULL, on_optimizer_before_eval = NULL, on_eval_after_design = NULL, + on_resample_begin = NULL, + on_resample_before_train = NULL, + on_resample_before_predict = NULL, + on_resample_end = NULL, on_eval_after_benchmark = NULL, on_eval_before_archive = NULL, on_optimizer_after_eval = NULL, @@ -163,6 +213,10 @@ callback_batch_tuning = function( on_optimization_begin, on_optimizer_before_eval, on_eval_after_design, + on_resample_begin, + on_resample_before_train, + on_resample_before_predict, + on_resample_end, on_eval_after_benchmark, on_eval_before_archive, on_optimizer_after_eval, @@ -175,6 +229,10 @@ callback_batch_tuning = function( "on_optimization_begin", "on_optimizer_before_eval", "on_eval_after_design", + "on_resample_begin", + "on_resample_before_train", + "on_resample_before_predict", + "on_resample_end", "on_eval_after_benchmark", "on_eval_before_archive", "on_optimizer_after_eval", diff --git a/R/ObjectiveTuningAsync.R b/R/ObjectiveTuningAsync.R index 217ef411..6805558b 100644 --- a/R/ObjectiveTuningAsync.R +++ b/R/ObjectiveTuningAsync.R @@ -31,7 +31,7 @@ ObjectiveTuningAsync = R6Class("ObjectiveTuningAsync", lg$debug("Resampling hyperparameter configuration") # resample hyperparameter configuration - private$.resample_result = resample(self$task, self$learner, self$resampling, store_models = self$store_models, allow_hotstart = TRUE, clone = character(0)) + private$.resample_result = resample(self$task, self$learner, self$resampling, store_models = self$store_models, allow_hotstart = TRUE, clone = character(0), callbacks = self$callbacks) call_back("on_eval_after_resample", self$callbacks, self$context) lg$debug("Aggregating performance") diff --git a/R/ObjectiveTuningBatch.R b/R/ObjectiveTuningBatch.R index 934fddab..3d8ad39a 100644 --- a/R/ObjectiveTuningBatch.R +++ b/R/ObjectiveTuningBatch.R @@ -74,7 +74,8 @@ ObjectiveTuningBatch = R6Class("ObjectiveTuningBatch", private$.benchmark_result = benchmark( design = private$.design, store_models = self$store_models, - clone = character(0)) + clone = character(0), + callbacks = self$callbacks) call_back("on_eval_after_benchmark", self$callbacks, self$context) # aggregate performance scores diff --git a/man/CallbackAsyncTuning.Rd b/man/CallbackAsyncTuning.Rd index 8b36f176..cab809b2 100644 --- a/man/CallbackAsyncTuning.Rd +++ b/man/CallbackAsyncTuning.Rd @@ -2,7 +2,7 @@ % Please edit documentation in R/CallbackAsyncTuning.R \name{CallbackAsyncTuning} \alias{CallbackAsyncTuning} -\title{Create Asynchronous Tuning Callback} +\title{Asynchronous Tuning Callback} \description{ Specialized \link[bbotk:CallbackAsync]{bbotk::CallbackAsync} for asynchronous tuning. Callbacks allow to customize the behavior of processes in mlr3tuning. @@ -20,6 +20,22 @@ For more information on tuning callbacks see \code{\link[=callback_async_tuning] Stage called after xs is passed. Called in \code{ObjectiveTuningAsync$eval()}.} +\item{\code{on_resample_begin}}{(\verb{function()})\cr +Stage called at the beginning of an evaluation. +Called in \code{workhorse()} (internal).} + +\item{\code{on_resample_before_train}}{(\verb{function()})\cr +Stage called before training the learner. +Called in \code{workhorse()} (internal).} + +\item{\code{on_resample_before_predict}}{(\verb{function()})\cr +Stage called before predicting. +Called in \code{workhorse()} (internal).} + +\item{\code{on_resample_end}}{(\verb{function()})\cr +Stage called at the end of an evaluation. +Called in \code{workhorse()} (internal).} + \item{\code{on_eval_after_resample}}{(\verb{function()})\cr Stage called after hyperparameter configurations are evaluated. Called in \code{ObjectiveTuningAsync$eval()}.} diff --git a/man/CallbackBatchTuning.Rd b/man/CallbackBatchTuning.Rd index c09643b8..fa79b4f5 100644 --- a/man/CallbackBatchTuning.Rd +++ b/man/CallbackBatchTuning.Rd @@ -28,6 +28,22 @@ callback_batch_tuning("mlr3tuning.backup", Stage called after design is created. Called in \code{ObjectiveTuningBatch$eval_many()}.} +\item{\code{on_resample_begin}}{(\verb{function()})\cr +Stage called at the beginning of an evaluation. +Called in \code{workhorse()} (internal).} + +\item{\code{on_resample_before_train}}{(\verb{function()})\cr +Stage called before training the learner. +Called in \code{workhorse()} (internal).} + +\item{\code{on_resample_before_predict}}{(\verb{function()})\cr +Stage called before predicting. +Called in \code{workhorse()} (internal).} + +\item{\code{on_resample_end}}{(\verb{function()})\cr +Stage called at the end of an evaluation. +Called in \code{workhorse()} (internal).} + \item{\code{on_eval_after_benchmark}}{(\verb{function()})\cr Stage called after hyperparameter configurations are evaluated. Called in \code{ObjectiveTuningBatch$eval_many()}.} diff --git a/man/callback_async_tuning.Rd b/man/callback_async_tuning.Rd index 839688a7..9232a13e 100644 --- a/man/callback_async_tuning.Rd +++ b/man/callback_async_tuning.Rd @@ -12,6 +12,10 @@ callback_async_tuning( on_worker_begin = NULL, on_optimizer_before_eval = NULL, on_eval_after_xs = NULL, + on_resample_begin = NULL, + on_resample_before_train = NULL, + on_resample_before_predict = NULL, + on_resample_end = NULL, on_eval_after_resample = NULL, on_eval_before_archive = NULL, on_optimizer_after_eval = NULL, @@ -57,6 +61,30 @@ Called in \code{ObjectiveTuningAsync$eval()}. The functions must have two arguments named \code{callback} and \code{context}. The argument of \verb{$.eval(xs)} is available in the \code{context}.} +\item{on_resample_begin}{(\verb{function()})\cr +Stage called at the beginning of a resampling iteration. +Called in \code{workhorse()} (internal). +See also \code{\link[mlr3:callback_resample]{mlr3::callback_resample()}}. +The functions must have two arguments named \code{callback} and \code{context}.} + +\item{on_resample_before_train}{(\verb{function()})\cr +Stage called before training the learner. +Called in \code{workhorse()} (internal). +See also \code{\link[mlr3:callback_resample]{mlr3::callback_resample()}}. +The functions must have two arguments named \code{callback} and \code{context}.} + +\item{on_resample_before_predict}{(\verb{function()})\cr +Stage called before predicting. +Called in \code{workhorse()} (internal). +See also \code{\link[mlr3:callback_resample]{mlr3::callback_resample()}}. +The functions must have two arguments named \code{callback} and \code{context}.} + +\item{on_resample_end}{(\verb{function()})\cr +Stage called at the end of a resampling iteration. +Called in \code{workhorse()} (internal). +See also \code{\link[mlr3:callback_resample]{mlr3::callback_resample()}}. +The functions must have two arguments named \code{callback} and \code{context}.} + \item{on_eval_after_resample}{(\verb{function()})\cr Stage called after a hyperparameter configuration is evaluated. Called in \code{ObjectiveTuningAsync$eval()}. @@ -122,6 +150,12 @@ The stages are prefixed with \verb{on_*}. - on_optimizer_before_eval Start Evaluation - on_eval_after_xs + Start Resampling Iteration + - on_resample_begin + - on_resample_before_train + - on_resample_before_predict + - on_resample_end + End Resampling Iteration - on_eval_after_resample - on_eval_before_archive End Evaluation @@ -142,5 +176,5 @@ A tuning callback works with \link{ContextAsyncTuning}. \details{ When implementing a callback, each function must have two arguments named \code{callback} and \code{context}. A callback can write data to the state (\verb{$state}), e.g. settings that affect the callback itself. -Tuning callbacks access \link{ContextAsyncTuning}. +Tuning callbacks access \link{ContextAsyncTuning} and \link[mlr3:ContextResample]{mlr3::ContextResample}. } diff --git a/man/callback_batch_tuning.Rd b/man/callback_batch_tuning.Rd index 4d84e52c..d4ef24ec 100644 --- a/man/callback_batch_tuning.Rd +++ b/man/callback_batch_tuning.Rd @@ -11,6 +11,10 @@ callback_batch_tuning( on_optimization_begin = NULL, on_optimizer_before_eval = NULL, on_eval_after_design = NULL, + on_resample_begin = NULL, + on_resample_before_train = NULL, + on_resample_before_predict = NULL, + on_resample_end = NULL, on_eval_after_benchmark = NULL, on_eval_before_archive = NULL, on_optimizer_after_eval = NULL, @@ -50,6 +54,30 @@ The functions must have two arguments named \code{callback} and \code{context}. The arguments of \verb{$eval_many(xss, resampling)} are available in \code{context}. Additionally, the \code{design} is available in \code{context}.} +\item{on_resample_begin}{(\verb{function()})\cr +Stage called at the beginning of a resampling iteration. +Called in \code{workhorse()} (internal). +See also \code{\link[mlr3:callback_resample]{mlr3::callback_resample()}}. +The functions must have two arguments named \code{callback} and \code{context}.} + +\item{on_resample_before_train}{(\verb{function()})\cr +Stage called before training the learner. +Called in \code{workhorse()} (internal). +See also \code{\link[mlr3:callback_resample]{mlr3::callback_resample()}}. +The functions must have two arguments named \code{callback} and \code{context}.} + +\item{on_resample_before_predict}{(\verb{function()})\cr +Stage called before predicting. +Called in \code{workhorse()} (internal). +See also \code{\link[mlr3:callback_resample]{mlr3::callback_resample()}}. +The functions must have two arguments named \code{callback} and \code{context}.} + +\item{on_resample_end}{(\verb{function()})\cr +Stage called at the end of a resampling iteration. +Called in \code{workhorse()} (internal). +See also \code{\link[mlr3:callback_resample]{mlr3::callback_resample()}}. +The functions must have two arguments named \code{callback} and \code{context}.} + \item{on_eval_after_benchmark}{(\verb{function()})\cr Stage called after hyperparameter configurations are evaluated. Called in \code{ObjectiveTuningBatch$eval_many()}. @@ -111,6 +139,12 @@ The stages are prefixed with \verb{on_*}. - on_optimizer_before_eval Start Evaluation - on_eval_after_design + Start Resampling Iteration + - on_resample_begin + - on_resample_before_train + - on_resample_before_predict + - on_resample_end + End Resampling Iteration - on_eval_after_benchmark - on_eval_before_archive End Evaluation @@ -124,7 +158,7 @@ End Tuning }\if{html}{\out{}} See also the section on parameters for more information on the stages. -A tuning callback works with \link{ContextBatchTuning}. +A tuning callback works with \link{ContextBatchTuning} and \link[mlr3:ContextResample]{mlr3::ContextResample}. } \details{ When implementing a callback, each function must have two arguments named \code{callback} and \code{context}. diff --git a/tests/testthat/test_ArchiveAsyncTuning.R b/tests/testthat/test_ArchiveAsyncTuning.R index 8246fa53..cdc9465e 100644 --- a/tests/testthat/test_ArchiveAsyncTuning.R +++ b/tests/testthat/test_ArchiveAsyncTuning.R @@ -12,6 +12,9 @@ test_that("ArchiveAsyncTuning access methods work", { terminator = trm("evals", n_evals = 20), store_benchmark_result = TRUE ) + + expect_benchmark_result(instance$archive$benchmark_result) + tuner = tnr("async_random_search") tuner$optimize(instance) diff --git a/tests/testthat/test_CallbackAsyncTuning.R b/tests/testthat/test_CallbackAsyncTuning.R index ffad721e..6170969e 100644 --- a/tests/testthat/test_CallbackAsyncTuning.R +++ b/tests/testthat/test_CallbackAsyncTuning.R @@ -346,3 +346,142 @@ test_that("on_result in TuningInstanceBatchMultiCrit works", { expect_equal(unique(instance$result$classif.ce), 0.7) }) +# stages in mlr3 workhorse ----------------------------------------------------- + +test_that("on_resample_begin works", { + skip_on_cran() + skip_if_not_installed("rush") + flush_redis() + + callback = callback_async_tuning("test", + on_resample_begin = function(callback, context) { + # expect_* does not work + assert_task(context$task) + assert_learner(context$learner) + assert_resampling(context$resampling) + checkmate::assert_number(context$iteration) + checkmate::assert_null(context$pdatas) + context$data_extra = list(success = TRUE) + } + ) + + rush::rush_plan(n_workers = 2) + instance = tune( + tuner = tnr("async_random_search"), + task = tsk("pima"), + learner = lrn("classif.rpart", minsplit = to_tune(1, 10)), + resampling = rsmp ("holdout"), + measures = msr("classif.ce"), + term_evals = 2, + callbacks = callback) + + expect_class(instance$objective$context, "ContextAsyncTuning") + + walk(as.data.table(instance$archive$benchmark_result)$data_extra, function(data_extra) { + expect_true(data_extra$success) + }) +}) + +test_that("on_resample_before_train works", { + skip_on_cran() + skip_if_not_installed("rush") + flush_redis() + + callback = callback_async_tuning("test", + on_resample_before_train = function(callback, context) { + assert_task(context$task) + assert_learner(context$learner) + assert_resampling(context$resampling) + checkmate::assert_number(context$iteration) + checkmate::assert_null(context$pdatas) + context$data_extra = list(success = TRUE) + } + ) + + rush::rush_plan(n_workers = 2) + instance = tune( + tuner = tnr("async_random_search"), + task = tsk("pima"), + learner = lrn("classif.rpart", minsplit = to_tune(1, 10)), + resampling = rsmp ("holdout"), + measures = msr("classif.ce"), + term_evals = 2, + callbacks = callback) + + expect_class(instance$objective$context, "ContextAsyncTuning") + + walk(as.data.table(instance$archive$benchmark_result)$data_extra, function(data_extra) { + expect_true(data_extra$success) + }) +}) + +test_that("on_resample_before_predict works", { + skip_on_cran() + skip_if_not_installed("rush") + flush_redis() + + callback = callback_async_tuning("test", + on_resample_before_predict = function(callback, context) { + assert_task(context$task) + assert_learner(context$learner) + assert_resampling(context$resampling) + checkmate::assert_null(context$pdatas) + context$data_extra = list(success = TRUE) + } + ) + + rush::rush_plan(n_workers = 2) + instance = tune( + tuner = tnr("async_random_search"), + task = tsk("pima"), + learner = lrn("classif.rpart", minsplit = to_tune(1, 10)), + resampling = rsmp ("holdout"), + measures = msr("classif.ce"), + term_evals = 2, + callbacks = callback) + + expect_class(instance$objective$context, "ContextAsyncTuning") + + walk(as.data.table(instance$archive$benchmark_result)$data_extra, function(data_extra) { + expect_true(data_extra$success) + }) +}) + +test_that("on_resample_end works", { + skip_on_cran() + skip_if_not_installed("rush") + flush_redis() + + callback = callback_async_tuning("test", + on_resample_end = function(callback, context) { + # expect_* does not work + assert_task(context$task) + assert_learner(context$learner) + assert_resampling(context$resampling) + checkmate::assert_number(context$iteration) + checkmate::assert_class(context$pdatas$test, "PredictionData") + context$learner$state = mlr3misc::insert_named(context$learner$state, list(state_success = TRUE)) + context$data_extra = list(success = TRUE) + } + ) + + rush::rush_plan(n_workers = 2) + instance = tune( + tuner = tnr("async_random_search"), + task = tsk("pima"), + learner = lrn("classif.rpart", minsplit = to_tune(1, 10)), + resampling = rsmp ("holdout"), + measures = msr("classif.ce"), + term_evals = 2, + callbacks = callback) + + expect_class(instance$objective$context, "ContextAsyncTuning") + + walk(as.data.table(instance$archive$benchmark_result)$data_extra, function(data_extra) { + expect_true(data_extra$success) + }) + + walk(instance$archive$benchmark_result$score()$learner, function(learner, ...) { + expect_true(learner$state$state_success) + }) +}) diff --git a/tests/testthat/test_CallbackBatchTuning.R b/tests/testthat/test_CallbackBatchTuning.R index 750a0943..4ab0fc92 100644 --- a/tests/testthat/test_CallbackBatchTuning.R +++ b/tests/testthat/test_CallbackBatchTuning.R @@ -258,4 +258,117 @@ test_that("on_result in TuningInstanceBatchMultiCrit works", { expect_equal(unique(instance$result$classif.ce), 0.7) }) +# stages in mlr3 workhorse ----------------------------------------------------- + +test_that("on_resample_begin works", { + + callback = callback_batch_tuning("test", + on_resample_begin = function(callback, context) { + # expect_* does not work + assert_task(context$task) + assert_learner(context$learner) + assert_resampling(context$resampling) + checkmate::assert_number(context$iteration) + checkmate::assert_null(context$pdatas) + context$data_extra = list(success = TRUE) + } + ) + + instance = tune( + tuner = tnr("random_search", batch_size = 1), + task = tsk("pima"), + learner = lrn("classif.rpart", minsplit = to_tune(1, 10)), + resampling = rsmp ("holdout"), + measures = msrs(c("classif.ce", "classif.acc")), + term_evals = 2, + callbacks = callback) + + expect_class(instance$objective$context, "ContextBatchTuning") + walk(as.data.table(instance$archive$benchmark_result)$data_extra, function(data_extra) { + expect_true(data_extra$success) + }) +}) + +test_that("on_resample_before_train works", { + callback = callback_batch_tuning("test", + on_resample_before_train = function(callback, context) { + assert_task(context$task) + assert_learner(context$learner) + assert_resampling(context$resampling) + checkmate::assert_number(context$iteration) + checkmate::assert_null(context$pdatas) + context$data_extra = list(success = TRUE) + } + ) + + instance = tune( + tuner = tnr("random_search", batch_size = 1), + task = tsk("pima"), + learner = lrn("classif.rpart", minsplit = to_tune(1, 10)), + resampling = rsmp ("holdout"), + measures = msrs(c("classif.ce", "classif.acc")), + term_evals = 2, + callbacks = callback) + + expect_class(instance$objective$context, "ContextBatchTuning") + + walk(as.data.table(instance$archive$benchmark_result)$data_extra, function(data_extra) { + expect_true(data_extra$success) + }) +}) + +test_that("on_resample_before_predict works", { + callback = callback_batch_tuning("test", + on_resample_before_predict = function(callback, context) { + assert_task(context$task) + assert_learner(context$learner) + assert_resampling(context$resampling) + checkmate::assert_null(context$pdatas) + context$data_extra = list(success = TRUE) + } + ) + + instance = tune( + tuner = tnr("random_search", batch_size = 1), + task = tsk("pima"), + learner = lrn("classif.rpart", minsplit = to_tune(1, 10)), + resampling = rsmp ("holdout"), + measures = msrs(c("classif.ce", "classif.acc")), + term_evals = 2, + callbacks = callback) + + expect_class(instance$objective$context, "ContextBatchTuning") + + walk(as.data.table(instance$archive$benchmark_result)$data_extra, function(data_extra) { + expect_true(data_extra$success) + }) +}) + +test_that("on_resample_end works", { + callback = callback_batch_tuning("test", + on_resample_end = function(callback, context) { + assert_task(context$task) + assert_learner(context$learner) + assert_resampling(context$resampling) + checkmate::assert_number(context$iteration) + checkmate::assert_class(context$pdatas$test, "PredictionData") + context$data_extra = list(success = TRUE) + } + ) + + instance = tune( + tuner = tnr("random_search", batch_size = 1), + task = tsk("pima"), + learner = lrn("classif.rpart", minsplit = to_tune(1, 10)), + resampling = rsmp ("holdout"), + measures = msrs(c("classif.ce", "classif.acc")), + term_evals = 2, + callbacks = callback) + + expect_class(instance$objective$context, "ContextBatchTuning") + + walk(as.data.table(instance$archive$benchmark_result)$data_extra, function(data_extra) { + expect_true(data_extra$success) + }) +})