Skip to content

Commit

Permalink
feat: add resample stages to tuning callbacks (#479)
Browse files Browse the repository at this point in the history
* feat: add evaluation callback stages

* store_models

* ...

* ...

* feat: save ArchiveAsyncTuning to a data.table with ArchiveAsyncTuningFrozen

* ...

* ...

* ...

* ...

* ...

* ...

* ...

* ...

* ...

* ...

* ...

* ...

* ...
  • Loading branch information
be-marc authored Feb 11, 2025
1 parent f4d451e commit 4504edf
Show file tree
Hide file tree
Showing 13 changed files with 484 additions and 10 deletions.
3 changes: 3 additions & 0 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,9 @@ Suggests:
rpart,
testthat (>= 3.0.0),
xgboost
Remotes:
mlr-org/mlr3,
mlr-org/bbotk
VignetteBuilder:
knitr
Config/testthat/edition: 3
Expand Down
3 changes: 1 addition & 2 deletions R/ArchiveAsyncTuning.R
Original file line number Diff line number Diff line change
Expand Up @@ -170,8 +170,7 @@ ArchiveAsyncTuning = R6Class("ArchiveAsyncTuning",
# cache benchmark result
if (self$rush$n_finished_tasks > private$.benchmark_result$n_resample_results) {
bmrs = map(self$finished_data$resample_result, as_benchmark_result)
init = BenchmarkResult$new()
private$.benchmark_result = Reduce(function(lhs, rhs) lhs$combine(rhs), bmrs, init = init)
private$.benchmark_result = Reduce(function(lhs, rhs) lhs$combine(rhs), bmrs)
}
private$.benchmark_result
}
Expand Down
62 changes: 60 additions & 2 deletions R/CallbackAsyncTuning.R
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
#' @title Create Asynchronous Tuning Callback
#' @title Asynchronous Tuning Callback
#'
#' @description
#' Specialized [bbotk::CallbackAsync] for asynchronous tuning.
Expand All @@ -17,6 +17,26 @@ CallbackAsyncTuning = R6Class("CallbackAsyncTuning",
#' Called in `ObjectiveTuningAsync$eval()`.
on_eval_after_xs = NULL,

#' @field on_resample_begin (`function()`)\cr
#' Stage called at the beginning of an evaluation.
#' Called in `workhorse()` (internal).
on_resample_begin = NULL,

#' @field on_resample_before_train (`function()`)\cr
#' Stage called before training the learner.
#' Called in `workhorse()` (internal).
on_resample_before_train = NULL,

#' @field on_resample_before_predict (`function()`)\cr
#' Stage called before predicting.
#' Called in `workhorse()` (internal).
on_resample_before_predict = NULL,

#' @field on_resample_end (`function()`)\cr
#' Stage called at the end of an evaluation.
#' Called in `workhorse()` (internal).
on_resample_end = NULL,

#' @field on_eval_after_resample (`function()`)\cr
#' Stage called after hyperparameter configurations are evaluated.
#' Called in `ObjectiveTuningAsync$eval()`.
Expand Down Expand Up @@ -52,6 +72,12 @@ CallbackAsyncTuning = R6Class("CallbackAsyncTuning",
#' - on_optimizer_before_eval
#' Start Evaluation
#' - on_eval_after_xs
#' Start Resampling Iteration
#' - on_resample_begin
#' - on_resample_before_train
#' - on_resample_before_predict
#' - on_resample_end
#' End Resampling Iteration
#' - on_eval_after_resample
#' - on_eval_before_archive
#' End Evaluation
Expand All @@ -72,7 +98,7 @@ CallbackAsyncTuning = R6Class("CallbackAsyncTuning",
#' @details
#' When implementing a callback, each function must have two arguments named `callback` and `context`.
#' A callback can write data to the state (`$state`), e.g. settings that affect the callback itself.
#' Tuning callbacks access [ContextAsyncTuning].
#' Tuning callbacks access [ContextAsyncTuning] and [mlr3::ContextResample].
#'
#' @param id (`character(1)`)\cr
#' Identifier for the new instance.
Expand Down Expand Up @@ -101,6 +127,26 @@ CallbackAsyncTuning = R6Class("CallbackAsyncTuning",
#' Called in `ObjectiveTuningAsync$eval()`.
#' The functions must have two arguments named `callback` and `context`.
#' The argument of `$.eval(xs)` is available in the `context`.
#' @param on_resample_begin (`function()`)\cr
#' Stage called at the beginning of a resampling iteration.
#' Called in `workhorse()` (internal).
#' See also [mlr3::callback_resample()].
#' The functions must have two arguments named `callback` and `context`.
#' @param on_resample_before_train (`function()`)\cr
#' Stage called before training the learner.
#' Called in `workhorse()` (internal).
#' See also [mlr3::callback_resample()].
#' The functions must have two arguments named `callback` and `context`.
#' @param on_resample_before_predict (`function()`)\cr
#' Stage called before predicting.
#' Called in `workhorse()` (internal).
#' See also [mlr3::callback_resample()].
#' The functions must have two arguments named `callback` and `context`.
#' @param on_resample_end (`function()`)\cr
#' Stage called at the end of a resampling iteration.
#' Called in `workhorse()` (internal).
#' See also [mlr3::callback_resample()].
#' The functions must have two arguments named `callback` and `context`.
#' @param on_eval_after_resample (`function()`)\cr
#' Stage called after a hyperparameter configuration is evaluated.
#' Called in `ObjectiveTuningAsync$eval()`.
Expand Down Expand Up @@ -152,6 +198,10 @@ callback_async_tuning = function(
on_worker_begin = NULL,
on_optimizer_before_eval = NULL,
on_eval_after_xs = NULL,
on_resample_begin = NULL,
on_resample_before_train = NULL,
on_resample_before_predict = NULL,
on_resample_end = NULL,
on_eval_after_resample = NULL,
on_eval_before_archive = NULL,
on_optimizer_after_eval = NULL,
Expand All @@ -167,6 +217,10 @@ callback_async_tuning = function(
on_worker_begin,
on_optimizer_before_eval,
on_eval_after_xs,
on_resample_begin,
on_resample_before_train,
on_resample_before_predict,
on_resample_end,
on_eval_after_resample,
on_eval_before_archive,
on_optimizer_after_eval,
Expand All @@ -181,6 +235,10 @@ callback_async_tuning = function(
"on_worker_begin",
"on_optimizer_before_eval",
"on_eval_after_xs",
"on_resample_begin",
"on_resample_before_train",
"on_resample_before_predict",
"on_resample_end",
"on_eval_after_resample",
"on_eval_before_archive",
"on_optimizer_after_eval",
Expand Down
60 changes: 59 additions & 1 deletion R/CallbackBatchTuning.R
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,26 @@ CallbackBatchTuning= R6Class("CallbackBatchTuning",
#' Called in `ObjectiveTuningBatch$eval_many()`.
on_eval_after_design = NULL,

#' @field on_resample_begin (`function()`)\cr
#' Stage called at the beginning of an evaluation.
#' Called in `workhorse()` (internal).
on_resample_begin = NULL,

#' @field on_resample_before_train (`function()`)\cr
#' Stage called before training the learner.
#' Called in `workhorse()` (internal).
on_resample_before_train = NULL,

#' @field on_resample_before_predict (`function()`)\cr
#' Stage called before predicting.
#' Called in `workhorse()` (internal).
on_resample_before_predict = NULL,

#' @field on_resample_end (`function()`)\cr
#' Stage called at the end of an evaluation.
#' Called in `workhorse()` (internal).
on_resample_end = NULL,

#' @field on_eval_after_benchmark (`function()`)\cr
#' Stage called after hyperparameter configurations are evaluated.
#' Called in `ObjectiveTuningBatch$eval_many()`.
Expand Down Expand Up @@ -57,6 +77,12 @@ CallbackBatchTuning= R6Class("CallbackBatchTuning",
#' - on_optimizer_before_eval
#' Start Evaluation
#' - on_eval_after_design
#' Start Resampling Iteration
#' - on_resample_begin
#' - on_resample_before_train
#' - on_resample_before_predict
#' - on_resample_end
#' End Resampling Iteration
#' - on_eval_after_benchmark
#' - on_eval_before_archive
#' End Evaluation
Expand All @@ -70,7 +96,7 @@ CallbackBatchTuning= R6Class("CallbackBatchTuning",
#' ```
#'
#' See also the section on parameters for more information on the stages.
#' A tuning callback works with [ContextBatchTuning].
#' A tuning callback works with [ContextBatchTuning] and [mlr3::ContextResample].
#'
#' @details
#' When implementing a callback, each function must have two arguments named `callback` and `context`.
Expand Down Expand Up @@ -100,6 +126,26 @@ CallbackBatchTuning= R6Class("CallbackBatchTuning",
#' The functions must have two arguments named `callback` and `context`.
#' The arguments of `$eval_many(xss, resampling)` are available in `context`.
#' Additionally, the `design` is available in `context`.
#' @param on_resample_begin (`function()`)\cr
#' Stage called at the beginning of a resampling iteration.
#' Called in `workhorse()` (internal).
#' See also [mlr3::callback_resample()].
#' The functions must have two arguments named `callback` and `context`.
#' @param on_resample_before_train (`function()`)\cr
#' Stage called before training the learner.
#' Called in `workhorse()` (internal).
#' See also [mlr3::callback_resample()].
#' The functions must have two arguments named `callback` and `context`.
#' @param on_resample_before_predict (`function()`)\cr
#' Stage called before predicting.
#' Called in `workhorse()` (internal).
#' See also [mlr3::callback_resample()].
#' The functions must have two arguments named `callback` and `context`.
#' @param on_resample_end (`function()`)\cr
#' Stage called at the end of a resampling iteration.
#' Called in `workhorse()` (internal).
#' See also [mlr3::callback_resample()].
#' The functions must have two arguments named `callback` and `context`.
#' @param on_eval_after_benchmark (`function()`)\cr
#' Stage called after hyperparameter configurations are evaluated.
#' Called in `ObjectiveTuningBatch$eval_many()`.
Expand Down Expand Up @@ -150,6 +196,10 @@ callback_batch_tuning = function(
on_optimization_begin = NULL,
on_optimizer_before_eval = NULL,
on_eval_after_design = NULL,
on_resample_begin = NULL,
on_resample_before_train = NULL,
on_resample_before_predict = NULL,
on_resample_end = NULL,
on_eval_after_benchmark = NULL,
on_eval_before_archive = NULL,
on_optimizer_after_eval = NULL,
Expand All @@ -163,6 +213,10 @@ callback_batch_tuning = function(
on_optimization_begin,
on_optimizer_before_eval,
on_eval_after_design,
on_resample_begin,
on_resample_before_train,
on_resample_before_predict,
on_resample_end,
on_eval_after_benchmark,
on_eval_before_archive,
on_optimizer_after_eval,
Expand All @@ -175,6 +229,10 @@ callback_batch_tuning = function(
"on_optimization_begin",
"on_optimizer_before_eval",
"on_eval_after_design",
"on_resample_begin",
"on_resample_before_train",
"on_resample_before_predict",
"on_resample_end",
"on_eval_after_benchmark",
"on_eval_before_archive",
"on_optimizer_after_eval",
Expand Down
2 changes: 1 addition & 1 deletion R/ObjectiveTuningAsync.R
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ ObjectiveTuningAsync = R6Class("ObjectiveTuningAsync",
lg$debug("Resampling hyperparameter configuration")

# resample hyperparameter configuration
private$.resample_result = resample(self$task, self$learner, self$resampling, store_models = self$store_models, allow_hotstart = TRUE, clone = character(0))
private$.resample_result = resample(self$task, self$learner, self$resampling, store_models = self$store_models, allow_hotstart = TRUE, clone = character(0), callbacks = self$callbacks)
call_back("on_eval_after_resample", self$callbacks, self$context)

lg$debug("Aggregating performance")
Expand Down
3 changes: 2 additions & 1 deletion R/ObjectiveTuningBatch.R
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,8 @@ ObjectiveTuningBatch = R6Class("ObjectiveTuningBatch",
private$.benchmark_result = benchmark(
design = private$.design,
store_models = self$store_models,
clone = character(0))
clone = character(0),
callbacks = self$callbacks)
call_back("on_eval_after_benchmark", self$callbacks, self$context)

# aggregate performance scores
Expand Down
18 changes: 17 additions & 1 deletion man/CallbackAsyncTuning.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

16 changes: 16 additions & 0 deletions man/CallbackBatchTuning.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

36 changes: 35 additions & 1 deletion man/callback_async_tuning.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit 4504edf

Please sign in to comment.