Skip to content

Commit 9f5146b

Browse files
committed
async
1 parent a3e0778 commit 9f5146b

14 files changed

+190
-10
lines changed

DESCRIPTION

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ Depends:
2929
paradox (>= 1.0.1),
3030
R (>= 3.1.0)
3131
Imports:
32-
bbotk (>= 1.2.0),
32+
bbotk (>= 1.3.0),
3333
checkmate (>= 2.0.0),
3434
data.table,
3535
lgr,
@@ -50,8 +50,6 @@ Suggests:
5050
rpart,
5151
testthat (>= 3.0.0),
5252
xgboost
53-
Remotes:
54-
mlr-org/bbotk
5553
VignetteBuilder:
5654
knitr
5755
Config/testthat/edition: 3

NAMESPACE

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,10 @@ export(TuningInstanceSingleCrit)
5050
export(as_search_space)
5151
export(as_tuner)
5252
export(as_tuners)
53+
export(assert_async_tuning_callback)
54+
export(assert_async_tuning_callbacks)
55+
export(assert_batch_tuning_callback)
56+
export(assert_batch_tuning_callbacks)
5357
export(assert_tuner)
5458
export(assert_tuner_async)
5559
export(assert_tuner_batch)

R/CallbackAsyncTuning.R

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -202,3 +202,28 @@ callback_async_tuning = function(
202202
iwalk(stages, function(stage, name) callback[[name]] = stage)
203203
callback
204204
}
205+
206+
207+
#' @title Assertions for Callbacks
208+
#'
209+
#' @description
210+
#' Assertions for [CallbackAsyncTuning] class.
211+
#'
212+
#' @param callback ([CallbackAsyncTuning]).
213+
#' @param null_ok (`logical(1)`)\cr
214+
#' If `TRUE`, `NULL` is allowed.
215+
#'
216+
#' @return [CallbackAsyncTuning | List of [CallbackAsyncTuning]s.
217+
#' @export
218+
assert_async_tuning_callback = function(callback, null_ok = FALSE) {
219+
if (null_ok && is.null(callback)) return(invisible(NULL))
220+
assert_class(callback, "CallbackAsyncTuning")
221+
invisible(callback)
222+
}
223+
224+
#' @export
225+
#' @param callbacks (list of [CallbackAsyncTuning]).
226+
#' @rdname assert_async_tuning_callback
227+
assert_async_tuning_callbacks = function(callbacks) {
228+
invisible(lapply(callbacks, assert_callback))
229+
}

R/CallbackBatchTuning.R

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -195,3 +195,27 @@ callback_batch_tuning = function(
195195
iwalk(stages, function(stage, name) callback[[name]] = stage)
196196
callback
197197
}
198+
199+
#' @title Assertions for Callbacks
200+
#'
201+
#' @description
202+
#' Assertions for [CallbackBatchTuning] class.
203+
#'
204+
#' @param callback ([CallbackBatchTuning]).
205+
#' @param null_ok (`logical(1)`)\cr
206+
#' If `TRUE`, `NULL` is allowed.
207+
#'
208+
#' @return [CallbackBatchTuning | List of [CallbackBatchTuning]s.
209+
#' @export
210+
assert_batch_tuning_callback = function(callback, null_ok = FALSE) {
211+
if (null_ok && is.null(callback)) return(invisible(NULL))
212+
assert_class(callback, "CallbackBatchTuning")
213+
invisible(callback)
214+
}
215+
216+
#' @export
217+
#' @param callbacks (list of [CallbackBatchTuning]).
218+
#' @rdname assert_batch_tuning_callback
219+
assert_batch_tuning_callbacks = function(callbacks) {
220+
invisible(lapply(callbacks, assert_callback))
221+
}

R/TuningInstanceAsyncMulticrit.R

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ TuningInstanceAsyncMultiCrit = R6Class("TuningInstanceAsyncMultiCrit",
5555
) {
5656
require_namespaces("rush")
5757
learner = assert_learner(as_learner(learner, clone = TRUE))
58+
callbacks = assert_async_tuning_callback(callbacks)
5859

5960
# tune token and search space
6061
if (!is.null(search_space) && length(learner$param_set$get_values(type = "only_token"))) {

R/TuningInstanceAsyncSingleCrit.R

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@ TuningInstanceAsyncSingleCrit = R6Class("TuningInstanceAsyncSingleCrit",
6565
) {
6666
require_namespaces("rush")
6767
learner = assert_learner(as_learner(learner, clone = TRUE))
68+
callbacks = assert_async_tuning_callback(callbacks)
6869

6970
# tune token and search space
7071
if (!is.null(search_space) && length(learner$param_set$get_values(type = "only_token"))) {

R/TuningInstanceBatchMulticrit.R

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,7 @@ TuningInstanceBatchMultiCrit = R6Class("TuningInstanceBatchMultiCrit",
9090
callbacks = NULL
9191
) {
9292
learner = assert_learner(as_learner(learner, clone = TRUE))
93+
callbacks = assert_batch_tuning_callback(callbacks)
9394

9495
# tune token and search space
9596
if (!is.null(search_space) && length(learner$param_set$get_values(type = "only_token"))) {

R/TuningInstanceBatchSingleCrit.R

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,7 @@ TuningInstanceBatchSingleCrit = R6Class("TuningInstanceBatchSingleCrit",
128128
callbacks = NULL
129129
) {
130130
learner = assert_learner(as_learner(learner, clone = TRUE))
131+
callbacks = assert_batch_tuning_callback(callbacks)
131132

132133
# tune token and search space
133134
if (!is.null(search_space) && length(learner$param_set$get_values(type = "only_token"))) {

R/mlr_callbacks.R

Lines changed: 57 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -300,10 +300,10 @@ load_callback_async_save_logs = function() {
300300
#' @name mlr3tuning.one_se_rule
301301
#'
302302
#' @description
303-
#' Selects the hyperparameter configuration with the smallest feature set within one standard error of the best.
304-
#' The learner must support `$selected_features()`.
303+
#' The one standard error rule takes the number of features into account when selecting the best hyperparameter configuration.
304+
#' Many learners support internal feature selection, which can be accessed via `$selected_features()`.
305+
#' The callback selects the hyperparameter configuration with the smallest feature set within one standard error of the best performing configuration.
305306
#' If there are multiple such hyperparameter configurations with the same number of features, the first one is selected.
306-
#' If the configurations have exactly the same performance but different number of features, the one with the smallest number of features is selected.
307307
#'
308308
#' @source
309309
#' `r format_bib("kuhn2013")`
@@ -326,12 +326,66 @@ load_callback_async_save_logs = function() {
326326
#' instance$result
327327
NULL
328328

329+
load_callback_async_one_se_rule = function() {
330+
callback_async_tuning("mlr3tuning.async_one_se_rule",
331+
label = "One Standard Error Rule Callback",
332+
man = "mlr3tuning::mlr3tuning.one_se_rule",
333+
334+
on_optimization_begin = function(callback, context) {
335+
if ("selected_features" %nin% context$instance$objective$learner$properties) {
336+
stopf("Learner '%s' does not support `$selected_features()`", context$instance$objective$learner$id)
337+
}
338+
callback$state$store_models = context$instance$objective$store_models
339+
context$instance$objective$store_models = TRUE
340+
},
341+
342+
on_eval_before_archive = function(callback, context) {
343+
res = context$resample_result$aggregate(msr("selected_features"))
344+
context$aggregated_performance$n_features = res
345+
if (!callback$state$store_models) {
346+
context$resample_result$discard(models = TRUE)
347+
}
348+
},
349+
350+
on_tuning_result_begin = function(callback, context) {
351+
archive = context$instance$archive
352+
data = as.data.table(archive)
353+
354+
# standard error
355+
y = data[[archive$cols_y]]
356+
se = sd(y) / sqrt(length(y))
357+
358+
if (se == 0) {
359+
# select smallest future set when all scores are the same
360+
best = data[which.min(get("n_features"))]
361+
} else {
362+
# select smallest future set within one standard error of the best
363+
best_y = context$result_y
364+
best = data[y > best_y - se & y < best_y + se, ][which.min(get("n_features"))]
365+
}
366+
367+
cols_x = context$instance$archive$cols_x
368+
cols_y = context$instance$archive$cols_y
369+
370+
context$result_xdt = best[, c(cols_x, "n_features"), with = FALSE]
371+
context$result_extra = best[, !c(cols_x, cols_y), with = FALSE]
372+
context$result_y = unlist(best[, cols_y, with = FALSE])
373+
374+
context$instance$objective$store_models = callback$state$store_models
375+
}
376+
)
377+
}
378+
379+
329380
load_callback_one_se_rule = function() {
330381
callback_batch_tuning("mlr3tuning.one_se_rule",
331382
label = "One Standard Error Rule Callback",
332383
man = "mlr3tuning::mlr3tuning.one_se_rule",
333384

334385
on_optimization_begin = function(callback, context) {
386+
if ("selected_features" %nin% context$instance$objective$learner$properties) {
387+
stopf("Learner '%s' does not support `$selected_features()`", context$instance$objective$learner$id)
388+
}
335389
callback$state$store_models = context$instance$objective$store_models
336390
context$instance$objective$store_models = TRUE
337391
},

R/zzz.R

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
x$add("mlr3tuning.async_measures", load_callback_async_measures)
2121
x$add("mlr3tuning.async_mlflow", load_callback_async_mlflow)
2222
x$add("mlr3tuning.async_save_logs", load_callback_async_save_logs)
23+
x$add("mlr3tuning.async_one_se_rule", load_callback_async_one_se_rule)
2324
x$add("mlr3tuning.backup", load_callback_backup)
2425
x$add("mlr3tuning.default_configuration", load_callback_default_configuration)
2526
x$add("mlr3tuning.measures", load_callback_measures)

man/assert_async_tuning_callback.Rd

Lines changed: 25 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/assert_batch_tuning_callback.Rd

Lines changed: 25 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/mlr3tuning.one_se_rule.Rd

Lines changed: 3 additions & 3 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

tests/testthat/test_mlr_callbacks.R

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -427,7 +427,27 @@ test_that("one se rule callback works", {
427427
)
428428

429429
expect_numeric(instance$archive$data$n_features)
430-
instance$result
430+
expect_numeric(instance$result$n_features)
431+
})
432+
433+
test_that("one se rule callback works", {
434+
skip_on_cran()
435+
skip_if_not_installed("rush")
436+
flush_redis()
437+
438+
rush::rush_plan(n_workers = 2)
439+
instance = ti_async(
440+
task = tsk("pima"),
441+
learner = lrn("classif.rpart", cp = to_tune(1e-04, 1e-1)),
442+
resampling = rsmp("cv", folds = 3),
443+
measures = msr("classif.ce"),
444+
terminator = trm("evals", n_evals = 5),
445+
callbacks = clbk("mlr3tuning.async_one_se_rule")
446+
)
431447

448+
tuner = tnr("async_random_search")
449+
tuner$optimize(instance)
432450

451+
expect_numeric(instance$archive$data$n_features)
452+
expect_numeric(instance$result$n_features)
433453
})

0 commit comments

Comments
 (0)