Skip to content

Commit 4504edf

Browse files
authored
feat: add resample stages to tuning callbacks (#479)
* feat: add evaluation callback stages * store_models * ... * ... * feat: save ArchiveAsyncTuning to a data.table with ArchiveAsyncTuningFrozen * ... * ... * ... * ... * ... * ... * ... * ... * ... * ... * ... * ... * ...
1 parent f4d451e commit 4504edf

13 files changed

+484
-10
lines changed

DESCRIPTION

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,9 @@ Suggests:
5050
rpart,
5151
testthat (>= 3.0.0),
5252
xgboost
53+
Remotes:
54+
mlr-org/mlr3,
55+
mlr-org/bbotk
5356
VignetteBuilder:
5457
knitr
5558
Config/testthat/edition: 3

R/ArchiveAsyncTuning.R

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -170,8 +170,7 @@ ArchiveAsyncTuning = R6Class("ArchiveAsyncTuning",
170170
# cache benchmark result
171171
if (self$rush$n_finished_tasks > private$.benchmark_result$n_resample_results) {
172172
bmrs = map(self$finished_data$resample_result, as_benchmark_result)
173-
init = BenchmarkResult$new()
174-
private$.benchmark_result = Reduce(function(lhs, rhs) lhs$combine(rhs), bmrs, init = init)
173+
private$.benchmark_result = Reduce(function(lhs, rhs) lhs$combine(rhs), bmrs)
175174
}
176175
private$.benchmark_result
177176
}

R/CallbackAsyncTuning.R

Lines changed: 60 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
#' @title Create Asynchronous Tuning Callback
1+
#' @title Asynchronous Tuning Callback
22
#'
33
#' @description
44
#' Specialized [bbotk::CallbackAsync] for asynchronous tuning.
@@ -17,6 +17,26 @@ CallbackAsyncTuning = R6Class("CallbackAsyncTuning",
1717
#' Called in `ObjectiveTuningAsync$eval()`.
1818
on_eval_after_xs = NULL,
1919

20+
#' @field on_resample_begin (`function()`)\cr
21+
#' Stage called at the beginning of an evaluation.
22+
#' Called in `workhorse()` (internal).
23+
on_resample_begin = NULL,
24+
25+
#' @field on_resample_before_train (`function()`)\cr
26+
#' Stage called before training the learner.
27+
#' Called in `workhorse()` (internal).
28+
on_resample_before_train = NULL,
29+
30+
#' @field on_resample_before_predict (`function()`)\cr
31+
#' Stage called before predicting.
32+
#' Called in `workhorse()` (internal).
33+
on_resample_before_predict = NULL,
34+
35+
#' @field on_resample_end (`function()`)\cr
36+
#' Stage called at the end of an evaluation.
37+
#' Called in `workhorse()` (internal).
38+
on_resample_end = NULL,
39+
2040
#' @field on_eval_after_resample (`function()`)\cr
2141
#' Stage called after hyperparameter configurations are evaluated.
2242
#' Called in `ObjectiveTuningAsync$eval()`.
@@ -52,6 +72,12 @@ CallbackAsyncTuning = R6Class("CallbackAsyncTuning",
5272
#' - on_optimizer_before_eval
5373
#' Start Evaluation
5474
#' - on_eval_after_xs
75+
#' Start Resampling Iteration
76+
#' - on_resample_begin
77+
#' - on_resample_before_train
78+
#' - on_resample_before_predict
79+
#' - on_resample_end
80+
#' End Resampling Iteration
5581
#' - on_eval_after_resample
5682
#' - on_eval_before_archive
5783
#' End Evaluation
@@ -72,7 +98,7 @@ CallbackAsyncTuning = R6Class("CallbackAsyncTuning",
7298
#' @details
7399
#' When implementing a callback, each function must have two arguments named `callback` and `context`.
74100
#' A callback can write data to the state (`$state`), e.g. settings that affect the callback itself.
75-
#' Tuning callbacks access [ContextAsyncTuning].
101+
#' Tuning callbacks access [ContextAsyncTuning] and [mlr3::ContextResample].
76102
#'
77103
#' @param id (`character(1)`)\cr
78104
#' Identifier for the new instance.
@@ -101,6 +127,26 @@ CallbackAsyncTuning = R6Class("CallbackAsyncTuning",
101127
#' Called in `ObjectiveTuningAsync$eval()`.
102128
#' The functions must have two arguments named `callback` and `context`.
103129
#' The argument of `$.eval(xs)` is available in the `context`.
130+
#' @param on_resample_begin (`function()`)\cr
131+
#' Stage called at the beginning of a resampling iteration.
132+
#' Called in `workhorse()` (internal).
133+
#' See also [mlr3::callback_resample()].
134+
#' The functions must have two arguments named `callback` and `context`.
135+
#' @param on_resample_before_train (`function()`)\cr
136+
#' Stage called before training the learner.
137+
#' Called in `workhorse()` (internal).
138+
#' See also [mlr3::callback_resample()].
139+
#' The functions must have two arguments named `callback` and `context`.
140+
#' @param on_resample_before_predict (`function()`)\cr
141+
#' Stage called before predicting.
142+
#' Called in `workhorse()` (internal).
143+
#' See also [mlr3::callback_resample()].
144+
#' The functions must have two arguments named `callback` and `context`.
145+
#' @param on_resample_end (`function()`)\cr
146+
#' Stage called at the end of a resampling iteration.
147+
#' Called in `workhorse()` (internal).
148+
#' See also [mlr3::callback_resample()].
149+
#' The functions must have two arguments named `callback` and `context`.
104150
#' @param on_eval_after_resample (`function()`)\cr
105151
#' Stage called after a hyperparameter configuration is evaluated.
106152
#' Called in `ObjectiveTuningAsync$eval()`.
@@ -152,6 +198,10 @@ callback_async_tuning = function(
152198
on_worker_begin = NULL,
153199
on_optimizer_before_eval = NULL,
154200
on_eval_after_xs = NULL,
201+
on_resample_begin = NULL,
202+
on_resample_before_train = NULL,
203+
on_resample_before_predict = NULL,
204+
on_resample_end = NULL,
155205
on_eval_after_resample = NULL,
156206
on_eval_before_archive = NULL,
157207
on_optimizer_after_eval = NULL,
@@ -167,6 +217,10 @@ callback_async_tuning = function(
167217
on_worker_begin,
168218
on_optimizer_before_eval,
169219
on_eval_after_xs,
220+
on_resample_begin,
221+
on_resample_before_train,
222+
on_resample_before_predict,
223+
on_resample_end,
170224
on_eval_after_resample,
171225
on_eval_before_archive,
172226
on_optimizer_after_eval,
@@ -181,6 +235,10 @@ callback_async_tuning = function(
181235
"on_worker_begin",
182236
"on_optimizer_before_eval",
183237
"on_eval_after_xs",
238+
"on_resample_begin",
239+
"on_resample_before_train",
240+
"on_resample_before_predict",
241+
"on_resample_end",
184242
"on_eval_after_resample",
185243
"on_eval_before_archive",
186244
"on_optimizer_after_eval",

R/CallbackBatchTuning.R

Lines changed: 59 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,26 @@ CallbackBatchTuning= R6Class("CallbackBatchTuning",
2424
#' Called in `ObjectiveTuningBatch$eval_many()`.
2525
on_eval_after_design = NULL,
2626

27+
#' @field on_resample_begin (`function()`)\cr
28+
#' Stage called at the beginning of an evaluation.
29+
#' Called in `workhorse()` (internal).
30+
on_resample_begin = NULL,
31+
32+
#' @field on_resample_before_train (`function()`)\cr
33+
#' Stage called before training the learner.
34+
#' Called in `workhorse()` (internal).
35+
on_resample_before_train = NULL,
36+
37+
#' @field on_resample_before_predict (`function()`)\cr
38+
#' Stage called before predicting.
39+
#' Called in `workhorse()` (internal).
40+
on_resample_before_predict = NULL,
41+
42+
#' @field on_resample_end (`function()`)\cr
43+
#' Stage called at the end of an evaluation.
44+
#' Called in `workhorse()` (internal).
45+
on_resample_end = NULL,
46+
2747
#' @field on_eval_after_benchmark (`function()`)\cr
2848
#' Stage called after hyperparameter configurations are evaluated.
2949
#' Called in `ObjectiveTuningBatch$eval_many()`.
@@ -57,6 +77,12 @@ CallbackBatchTuning= R6Class("CallbackBatchTuning",
5777
#' - on_optimizer_before_eval
5878
#' Start Evaluation
5979
#' - on_eval_after_design
80+
#' Start Resampling Iteration
81+
#' - on_resample_begin
82+
#' - on_resample_before_train
83+
#' - on_resample_before_predict
84+
#' - on_resample_end
85+
#' End Resampling Iteration
6086
#' - on_eval_after_benchmark
6187
#' - on_eval_before_archive
6288
#' End Evaluation
@@ -70,7 +96,7 @@ CallbackBatchTuning= R6Class("CallbackBatchTuning",
7096
#' ```
7197
#'
7298
#' See also the section on parameters for more information on the stages.
73-
#' A tuning callback works with [ContextBatchTuning].
99+
#' A tuning callback works with [ContextBatchTuning] and [mlr3::ContextResample].
74100
#'
75101
#' @details
76102
#' When implementing a callback, each function must have two arguments named `callback` and `context`.
@@ -100,6 +126,26 @@ CallbackBatchTuning= R6Class("CallbackBatchTuning",
100126
#' The functions must have two arguments named `callback` and `context`.
101127
#' The arguments of `$eval_many(xss, resampling)` are available in `context`.
102128
#' Additionally, the `design` is available in `context`.
129+
#' @param on_resample_begin (`function()`)\cr
130+
#' Stage called at the beginning of a resampling iteration.
131+
#' Called in `workhorse()` (internal).
132+
#' See also [mlr3::callback_resample()].
133+
#' The functions must have two arguments named `callback` and `context`.
134+
#' @param on_resample_before_train (`function()`)\cr
135+
#' Stage called before training the learner.
136+
#' Called in `workhorse()` (internal).
137+
#' See also [mlr3::callback_resample()].
138+
#' The functions must have two arguments named `callback` and `context`.
139+
#' @param on_resample_before_predict (`function()`)\cr
140+
#' Stage called before predicting.
141+
#' Called in `workhorse()` (internal).
142+
#' See also [mlr3::callback_resample()].
143+
#' The functions must have two arguments named `callback` and `context`.
144+
#' @param on_resample_end (`function()`)\cr
145+
#' Stage called at the end of a resampling iteration.
146+
#' Called in `workhorse()` (internal).
147+
#' See also [mlr3::callback_resample()].
148+
#' The functions must have two arguments named `callback` and `context`.
103149
#' @param on_eval_after_benchmark (`function()`)\cr
104150
#' Stage called after hyperparameter configurations are evaluated.
105151
#' Called in `ObjectiveTuningBatch$eval_many()`.
@@ -150,6 +196,10 @@ callback_batch_tuning = function(
150196
on_optimization_begin = NULL,
151197
on_optimizer_before_eval = NULL,
152198
on_eval_after_design = NULL,
199+
on_resample_begin = NULL,
200+
on_resample_before_train = NULL,
201+
on_resample_before_predict = NULL,
202+
on_resample_end = NULL,
153203
on_eval_after_benchmark = NULL,
154204
on_eval_before_archive = NULL,
155205
on_optimizer_after_eval = NULL,
@@ -163,6 +213,10 @@ callback_batch_tuning = function(
163213
on_optimization_begin,
164214
on_optimizer_before_eval,
165215
on_eval_after_design,
216+
on_resample_begin,
217+
on_resample_before_train,
218+
on_resample_before_predict,
219+
on_resample_end,
166220
on_eval_after_benchmark,
167221
on_eval_before_archive,
168222
on_optimizer_after_eval,
@@ -175,6 +229,10 @@ callback_batch_tuning = function(
175229
"on_optimization_begin",
176230
"on_optimizer_before_eval",
177231
"on_eval_after_design",
232+
"on_resample_begin",
233+
"on_resample_before_train",
234+
"on_resample_before_predict",
235+
"on_resample_end",
178236
"on_eval_after_benchmark",
179237
"on_eval_before_archive",
180238
"on_optimizer_after_eval",

R/ObjectiveTuningAsync.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ ObjectiveTuningAsync = R6Class("ObjectiveTuningAsync",
3131
lg$debug("Resampling hyperparameter configuration")
3232

3333
# resample hyperparameter configuration
34-
private$.resample_result = resample(self$task, self$learner, self$resampling, store_models = self$store_models, allow_hotstart = TRUE, clone = character(0))
34+
private$.resample_result = resample(self$task, self$learner, self$resampling, store_models = self$store_models, allow_hotstart = TRUE, clone = character(0), callbacks = self$callbacks)
3535
call_back("on_eval_after_resample", self$callbacks, self$context)
3636

3737
lg$debug("Aggregating performance")

R/ObjectiveTuningBatch.R

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,8 @@ ObjectiveTuningBatch = R6Class("ObjectiveTuningBatch",
7474
private$.benchmark_result = benchmark(
7575
design = private$.design,
7676
store_models = self$store_models,
77-
clone = character(0))
77+
clone = character(0),
78+
callbacks = self$callbacks)
7879
call_back("on_eval_after_benchmark", self$callbacks, self$context)
7980

8081
# aggregate performance scores

man/CallbackAsyncTuning.Rd

Lines changed: 17 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/CallbackBatchTuning.Rd

Lines changed: 16 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/callback_async_tuning.Rd

Lines changed: 35 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)