Skip to content

Commit

Permalink
async
Browse files Browse the repository at this point in the history
  • Loading branch information
be-marc committed Nov 8, 2024
1 parent a3e0778 commit 9f5146b
Show file tree
Hide file tree
Showing 14 changed files with 190 additions and 10 deletions.
4 changes: 1 addition & 3 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ Depends:
paradox (>= 1.0.1),
R (>= 3.1.0)
Imports:
bbotk (>= 1.2.0),
bbotk (>= 1.3.0),
checkmate (>= 2.0.0),
data.table,
lgr,
Expand All @@ -50,8 +50,6 @@ Suggests:
rpart,
testthat (>= 3.0.0),
xgboost
Remotes:
mlr-org/bbotk
VignetteBuilder:
knitr
Config/testthat/edition: 3
Expand Down
4 changes: 4 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,10 @@ export(TuningInstanceSingleCrit)
export(as_search_space)
export(as_tuner)
export(as_tuners)
export(assert_async_tuning_callback)
export(assert_async_tuning_callbacks)
export(assert_batch_tuning_callback)
export(assert_batch_tuning_callbacks)
export(assert_tuner)
export(assert_tuner_async)
export(assert_tuner_batch)
Expand Down
25 changes: 25 additions & 0 deletions R/CallbackAsyncTuning.R
Original file line number Diff line number Diff line change
Expand Up @@ -202,3 +202,28 @@ callback_async_tuning = function(
iwalk(stages, function(stage, name) callback[[name]] = stage)
callback
}


#' @title Assertions for Callbacks
#'
#' @description
#' Assertions for [CallbackAsyncTuning] class.
#'
#' @param callback ([CallbackAsyncTuning])\cr
#'   Object to check.
#' @param null_ok (`logical(1)`)\cr
#'   If `TRUE`, `NULL` is allowed.
#'
#' @return [CallbackAsyncTuning] | List of [CallbackAsyncTuning]s.
#' @export
assert_async_tuning_callback = function(callback, null_ok = FALSE) {
  # anything other than a permitted `NULL` has to pass the class check
  if (!null_ok || !is.null(callback)) {
    assert_class(callback, "CallbackAsyncTuning")
    return(invisible(callback))
  }
  invisible(NULL)
}

#' @export
#' @param callbacks (list of [CallbackAsyncTuning])\cr
#'   Every element is checked with [assert_async_tuning_callback()].
#' @rdname assert_async_tuning_callback
assert_async_tuning_callbacks = function(callbacks) {
  # Check each element against the async tuning callback class.
  # The generic `assert_callback()` used before never checked for "CallbackAsyncTuning".
  invisible(lapply(callbacks, assert_async_tuning_callback))
}
24 changes: 24 additions & 0 deletions R/CallbackBatchTuning.R
Original file line number Diff line number Diff line change
Expand Up @@ -195,3 +195,27 @@ callback_batch_tuning = function(
iwalk(stages, function(stage, name) callback[[name]] = stage)
callback
}

#' @title Assertions for Callbacks
#'
#' @description
#' Assertions for [CallbackBatchTuning] class.
#'
#' @param callback ([CallbackBatchTuning])\cr
#'   Object to check.
#' @param null_ok (`logical(1)`)\cr
#'   If `TRUE`, `NULL` is allowed.
#'
#' @return [CallbackBatchTuning] | List of [CallbackBatchTuning]s.
#' @export
assert_batch_tuning_callback = function(callback, null_ok = FALSE) {
  # anything other than a permitted `NULL` has to pass the class check
  if (!null_ok || !is.null(callback)) {
    assert_class(callback, "CallbackBatchTuning")
    return(invisible(callback))
  }
  invisible(NULL)
}

#' @export
#' @param callbacks (list of [CallbackBatchTuning])\cr
#'   Every element is checked with [assert_batch_tuning_callback()].
#' @rdname assert_batch_tuning_callback
assert_batch_tuning_callbacks = function(callbacks) {
  # Check each element against the batch tuning callback class.
  # The generic `assert_callback()` used before never checked for "CallbackBatchTuning".
  invisible(lapply(callbacks, assert_batch_tuning_callback))
}
1 change: 1 addition & 0 deletions R/TuningInstanceAsyncMulticrit.R
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ TuningInstanceAsyncMultiCrit = R6Class("TuningInstanceAsyncMultiCrit",
) {
require_namespaces("rush")
learner = assert_learner(as_learner(learner, clone = TRUE))
callbacks = assert_async_tuning_callback(callbacks)

# tune token and search space
if (!is.null(search_space) && length(learner$param_set$get_values(type = "only_token"))) {
Expand Down
1 change: 1 addition & 0 deletions R/TuningInstanceAsyncSingleCrit.R
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ TuningInstanceAsyncSingleCrit = R6Class("TuningInstanceAsyncSingleCrit",
) {
require_namespaces("rush")
learner = assert_learner(as_learner(learner, clone = TRUE))
callbacks = assert_async_tuning_callback(callbacks)

# tune token and search space
if (!is.null(search_space) && length(learner$param_set$get_values(type = "only_token"))) {
Expand Down
1 change: 1 addition & 0 deletions R/TuningInstanceBatchMulticrit.R
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ TuningInstanceBatchMultiCrit = R6Class("TuningInstanceBatchMultiCrit",
callbacks = NULL
) {
learner = assert_learner(as_learner(learner, clone = TRUE))
callbacks = assert_batch_tuning_callback(callbacks)

# tune token and search space
if (!is.null(search_space) && length(learner$param_set$get_values(type = "only_token"))) {
Expand Down
1 change: 1 addition & 0 deletions R/TuningInstanceBatchSingleCrit.R
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@ TuningInstanceBatchSingleCrit = R6Class("TuningInstanceBatchSingleCrit",
callbacks = NULL
) {
learner = assert_learner(as_learner(learner, clone = TRUE))
callbacks = assert_batch_tuning_callback(callbacks)

# tune token and search space
if (!is.null(search_space) && length(learner$param_set$get_values(type = "only_token"))) {
Expand Down
60 changes: 57 additions & 3 deletions R/mlr_callbacks.R
Original file line number Diff line number Diff line change
Expand Up @@ -300,10 +300,10 @@ load_callback_async_save_logs = function() {
#' @name mlr3tuning.one_se_rule
#'
#' @description
#' Selects the hyperparameter configuration with the smallest feature set within one standard error of the best.
#' The learner must support `$selected_features()`.
#' The one standard error rule takes the number of features into account when selecting the best hyperparameter configuration.
#' Many learners support internal feature selection, which can be accessed via `$selected_features()`.
#' The callback selects the hyperparameter configuration with the smallest feature set within one standard error of the best performing configuration.
#' If there are multiple such hyperparameter configurations with the same number of features, the first one is selected.
#' If the configurations have exactly the same performance but different number of features, the one with the smallest number of features is selected.
#'
#' @source
#' `r format_bib("kuhn2013")`
Expand All @@ -326,12 +326,66 @@ load_callback_async_save_logs = function() {
#' instance$result
NULL

# Constructs the asynchronous one standard error rule callback ("mlr3tuning.async_one_se_rule").
#
# Stages:
# * `on_optimization_begin`: stops if the learner does not list the "selected_features" property, remembers the
#   objective's `store_models` setting in the callback state and then switches it to `TRUE`.
# * `on_eval_before_archive`: aggregates `msr("selected_features")` on the resample result, writes it to
#   `n_features` of the aggregated performance and discards the models unless they were requested originally.
# * `on_tuning_result_begin`: replaces the tuning result with the configuration that has the fewest features among
#   those scoring within one standard error of `context$result_y`, then restores `store_models`.
load_callback_async_one_se_rule = function() {
  callback_async_tuning("mlr3tuning.async_one_se_rule",
    label = "One Standard Error Rule Callback",
    man = "mlr3tuning::mlr3tuning.one_se_rule",

    on_optimization_begin = function(callback, context) {
      if ("selected_features" %nin% context$instance$objective$learner$properties) {
        stopf("Learner '%s' does not support `$selected_features()`", context$instance$objective$learner$id)
      }
      # keep the user's setting so it can be restored once the result is determined
      callback$state$store_models = context$instance$objective$store_models
      # NOTE(review): models are presumably needed to compute the selected features - confirm
      context$instance$objective$store_models = TRUE
    },

    on_eval_before_archive = function(callback, context) {
      res = context$resample_result$aggregate(msr("selected_features"))
      context$aggregated_performance$n_features = res
      # drop the models again if they were only stored for this callback
      if (!callback$state$store_models) {
        context$resample_result$discard(models = TRUE)
      }
    },

    on_tuning_result_begin = function(callback, context) {
      archive = context$instance$archive
      data = as.data.table(archive)

      # standard error
      y = data[[archive$cols_y]]
      se = sd(y) / sqrt(length(y))

      # `sd()` is `NA` for fewer than two values, which would make `if (se == 0)` fail.
      # Check the length first so a single evaluation takes the same path as identical scores.
      if (length(y) < 2L || se == 0) {
        # select smallest feature set when all scores are the same
        best = data[which.min(get("n_features"))]
      } else {
        # select smallest feature set within one standard error of the best
        best_y = context$result_y
        best = data[y > best_y - se & y < best_y + se, ][which.min(get("n_features"))]
      }

      cols_x = context$instance$archive$cols_x
      cols_y = context$instance$archive$cols_y

      context$result_xdt = best[, c(cols_x, "n_features"), with = FALSE]
      context$result_extra = best[, !c(cols_x, cols_y), with = FALSE]
      context$result_y = unlist(best[, cols_y, with = FALSE])

      # restore the setting saved in `on_optimization_begin`
      context$instance$objective$store_models = callback$state$store_models
    }
  )
}


load_callback_one_se_rule = function() {
callback_batch_tuning("mlr3tuning.one_se_rule",
label = "One Standard Error Rule Callback",
man = "mlr3tuning::mlr3tuning.one_se_rule",

on_optimization_begin = function(callback, context) {
if ("selected_features" %nin% context$instance$objective$learner$properties) {
stopf("Learner '%s' does not support `$selected_features()`", context$instance$objective$learner$id)
}
callback$state$store_models = context$instance$objective$store_models
context$instance$objective$store_models = TRUE
},
Expand Down
1 change: 1 addition & 0 deletions R/zzz.R
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
x$add("mlr3tuning.async_measures", load_callback_async_measures)
x$add("mlr3tuning.async_mlflow", load_callback_async_mlflow)
x$add("mlr3tuning.async_save_logs", load_callback_async_save_logs)
x$add("mlr3tuning.async_one_se_rule", load_callback_async_one_se_rule)
x$add("mlr3tuning.backup", load_callback_backup)
x$add("mlr3tuning.default_configuration", load_callback_default_configuration)
x$add("mlr3tuning.measures", load_callback_measures)
Expand Down
25 changes: 25 additions & 0 deletions man/assert_async_tuning_callback.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

25 changes: 25 additions & 0 deletions man/assert_batch_tuning_callback.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 3 additions & 3 deletions man/mlr3tuning.one_se_rule.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

22 changes: 21 additions & 1 deletion tests/testthat/test_mlr_callbacks.R
Original file line number Diff line number Diff line change
Expand Up @@ -427,7 +427,27 @@ test_that("one se rule callback works", {
)

expect_numeric(instance$archive$data$n_features)
instance$result
expect_numeric(instance$result$n_features)
})

# Asynchronous counterpart of the batch one se rule test.
# The description names the async variant so that failures of the two tests can be told apart.
test_that("async one se rule callback works", {
  skip_on_cran()
  skip_if_not_installed("rush")
  flush_redis()

  rush::rush_plan(n_workers = 2)
  instance = ti_async(
    task = tsk("pima"),
    learner = lrn("classif.rpart", cp = to_tune(1e-04, 1e-1)),
    resampling = rsmp("cv", folds = 3),
    measures = msr("classif.ce"),
    terminator = trm("evals", n_evals = 5),
    callbacks = clbk("mlr3tuning.async_one_se_rule")
  )

  tuner = tnr("async_random_search")
  tuner$optimize(instance)

  # the callback adds `n_features` to the archive and to the tuning result
  expect_numeric(instance$archive$data$n_features)
  expect_numeric(instance$result$n_features)
})

0 comments on commit 9f5146b

Please sign in to comment.