Skip to content

Commit

Permalink
feat: add one se rule callback (#464)
Browse files Browse the repository at this point in the history
* feat: add one se rule callback

* ...

* ...

* ...

* ...

* ...

* ...

* async

* ...

* ...

* ...

* ...

* ...
  • Loading branch information
be-marc authored Nov 8, 2024
1 parent 4dc1160 commit ee1f2a4
Show file tree
Hide file tree
Showing 27 changed files with 1,355 additions and 162 deletions.
4 changes: 1 addition & 3 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ Depends:
paradox (>= 1.0.1),
R (>= 3.1.0)
Imports:
bbotk (>= 1.2.0),
bbotk (>= 1.3.0),
checkmate (>= 2.0.0),
data.table,
lgr,
Expand All @@ -50,8 +50,6 @@ Suggests:
rpart,
testthat (>= 3.0.0),
xgboost
Remotes:
mlr-org/bbotk
VignetteBuilder:
knitr
Config/testthat/edition: 3
Expand Down
5 changes: 5 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,10 @@ export(TuningInstanceSingleCrit)
export(as_search_space)
export(as_tuner)
export(as_tuners)
export(assert_async_tuning_callback)
export(assert_async_tuning_callbacks)
export(assert_batch_tuning_callback)
export(assert_batch_tuning_callbacks)
export(assert_tuner)
export(assert_tuner_async)
export(assert_tuner_batch)
Expand Down Expand Up @@ -88,5 +92,6 @@ importFrom(bbotk,trms)
importFrom(mlr3misc,clbk)
importFrom(mlr3misc,clbks)
importFrom(mlr3misc,mlr_callbacks)
importFrom(stats,sd)
importFrom(utils,bibentry)
importFrom(utils,tail)
4 changes: 4 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
# mlr3tuning (development version)

* feat: Add new callback `clbk("mlr3tuning.one_se_rule")` that selects the the hyperparameter configuration with the smallest feature set within one standard error of the best.
* feat: Add new stages `on_tuning_result_begin` and `on_result_begin` to `CallbackAsyncTuning` and `CallbackBatchTuning`.
* refactor: Rename stage `on_result` to `on_result_end` in `CallbackAsyncTuning` and `CallbackBatchTuning`.
* docs: Extend the `CallbackAsyncTuning` and `CallbackBatchTuning` documentation.
* compatibility: mlr3 0.22.0

# mlr3tuning 1.1.0
Expand Down
139 changes: 109 additions & 30 deletions R/CallbackAsyncTuning.R
Original file line number Diff line number Diff line change
Expand Up @@ -13,19 +13,24 @@ CallbackAsyncTuning = R6Class("CallbackAsyncTuning",
public = list(

#' @field on_eval_after_xs (`function()`)\cr
#' Stage called after xs is passed.
#' Called in `ObjectiveTuning$eval()`.
#' Stage called after xs is passed.
#' Called in `ObjectiveTuningAsync$eval()`.
on_eval_after_xs = NULL,

#' @field on_eval_after_resample (`function()`)\cr
#' Stage called after hyperparameter configurations are evaluated.
#' Called in `ObjectiveTuning$eval()`.
#' Stage called after hyperparameter configurations are evaluated.
#' Called in `ObjectiveTuningAsync$eval()`.
on_eval_after_resample = NULL,

#' @field on_eval_before_archive (`function()`)\cr
#' Stage called before performance values are written to the archive.
#' Called in `ObjectiveTuning$eval()`.
on_eval_before_archive = NULL
#' Stage called before performance values are written to the archive.
#' Called in `ObjectiveTuningAsync$eval()`.
on_eval_before_archive = NULL,

#' @field on_tuning_result_begin (`function()`)\cr
#' Stage called before the results are written.
#' Called in `TuningInstance*$assign_result()`.
on_tuning_result_begin = NULL
)
)

Expand Down Expand Up @@ -54,7 +59,9 @@ CallbackAsyncTuning = R6Class("CallbackAsyncTuning",
#' End Optimization on Worker
#' - on_worker_end
#' End Worker
#' - on_result
#' - on_tuning_result_begin
#' - on_result_begin
#' - on_result_end
#' - on_optimization_end
#' End Tuning
#' ```
Expand All @@ -68,39 +75,70 @@ CallbackAsyncTuning = R6Class("CallbackAsyncTuning",
#' Tuning callbacks access [ContextAsyncTuning].
#'
#' @param id (`character(1)`)\cr
#' Identifier for the new instance.
#' Identifier for the new instance.
#' @param label (`character(1)`)\cr
#' Label for the new instance.
#' Label for the new instance.
#' @param man (`character(1)`)\cr
#' String in the format `[pkg]::[topic]` pointing to a manual page for this object.
#' The referenced help package can be opened via method `$help()`.
#' String in the format `[pkg]::[topic]` pointing to a manual page for this object.
#' The referenced help package can be opened via method `$help()`.
#'
#' @param on_optimization_begin (`function()`)\cr
#' Stage called at the beginning of the optimization.
#' Called in `Optimizer$optimize()`.
#' Stage called at the beginning of the optimization.
#' Called in `Optimizer$optimize()`.
#' The functions must have two arguments named `callback` and `context`.
#' @param on_worker_begin (`function()`)\cr
#' Stage called at the beginning of the optimization on the worker.
#' Called in the worker loop.
#' Stage called at the beginning of the optimization on the worker.
#' Called in the worker loop.
#' The functions must have two arguments named `callback` and `context`.
#' @param on_optimizer_before_eval (`function()`)\cr
#' Stage called after the optimizer proposes points.
#' Called in `OptimInstance$.eval_point()`.
#' Stage called after the optimizer proposes points.
#' Called in `OptimInstance$.eval_point()`.
#' The functions must have two arguments named `callback` and `context`.
#' The argument of `instance$.eval_point(xs)` and `xs_trafoed` and `extra` are available in the `context`.
#' Or `xs` and `xs_trafoed` of `instance$.eval_queue()` are available in the `context`.
#' @param on_eval_after_xs (`function()`)\cr
#' Stage called after xs is passed.
#' Called in `ObjectiveTuning$eval()`.
#' Stage called after xs is passed to the objective.
#' Called in `ObjectiveTuningAsync$eval()`.
#' The functions must have two arguments named `callback` and `context`.
#' The argument of `$.eval(xs)` is available in the `context`.
#' @param on_eval_after_resample (`function()`)\cr
#' Stage called after a hyperparameter configuration is evaluated.
#' Called in `ObjectiveTuning$eval()`.
#' Stage called after a hyperparameter configuration is evaluated.
#' Called in `ObjectiveTuningAsync$eval()`.
#' The functions must have two arguments named `callback` and `context`.
#' The `resample_result` is available in the `context
#' @param on_eval_before_archive (`function()`)\cr
#' Stage called before performance values are written to the archive.
#' Called in `ObjectiveTuning$eval()`.
#' Stage called before performance values are written to the archive.
#' Called in `ObjectiveTuningAsync$eval()`.
#' The functions must have two arguments named `callback` and `context`.
#' The `aggregated_performance` is available in `context`.
#' @param on_optimizer_after_eval (`function()`)\cr
#' Stage called after points are evaluated.
#' Called in `OptimInstance$.eval_point()`.
#' Stage called after points are evaluated.
#' Called in `OptimInstance$.eval_point()`.
#' The functions must have two arguments named `callback` and `context`.
#' @param on_worker_end (`function()`)\cr
#' Stage called at the end of the optimization on the worker.
#' Called in the worker loop.
#' Stage called at the end of the optimization on the worker.
#' Called in the worker loop.
#' The functions must have two arguments named `callback` and `context`.
#' @param on_tuning_result_begin (`function()`)\cr
#' Stage called at the beginning of the result writing.
#' Called in `TuningInstance*$assign_result()`.
#' The functions must have two arguments named `callback` and `context`.
#' The arguments of `$assign_result(xdt, y, learner_param_vals, extra)` are available in `context`.
#' @param on_result_begin (`function()`)\cr
#' Stage called at the beginning of the result writing.
#' Called in `OptimInstance$assign_result()`.
#' The functions must have two arguments named `callback` and `context`.
#' The arguments of `$.assign_result(xdt, y, extra)` are available in the `context`.
#' @param on_result_end (`function()`)\cr
#' Stage called after the result is written.
#' Called in `OptimInstance$assign_result()`.
#' The functions must have two arguments named `callback` and `context`.
#' The final result `instance$result` is available in the `context`.
#' @param on_result (`function()`)\cr
#' Stage called after the result is written.
#' Called in `OptimInstance$assign_result()`.
#' Deprecated.
#' Use `on_result_end` instead.
#' Stage called after the result is written.
#' Called in `OptimInstance$assign_result()`.
#' @param on_optimization_end (`function()`)\cr
#' Stage called at the end of the optimization.
#' Called in `Optimizer$optimize()`.
Expand All @@ -118,6 +156,9 @@ callback_async_tuning = function(
on_eval_before_archive = NULL,
on_optimizer_after_eval = NULL,
on_worker_end = NULL,
on_tuning_result_begin = NULL,
on_result_begin = NULL,
on_result_end = NULL,
on_result = NULL,
on_optimization_end = NULL
) {
Expand All @@ -130,6 +171,9 @@ callback_async_tuning = function(
on_eval_before_archive,
on_optimizer_after_eval,
on_worker_end,
on_tuning_result_begin,
on_result_begin,
on_result_end,
on_result,
on_optimization_end),
c(
Expand All @@ -141,10 +185,45 @@ callback_async_tuning = function(
"on_eval_before_archive",
"on_optimizer_after_eval",
"on_worker_end",
"on_tuning_result_begin",
"on_result_begin",
"on_result_end",
"on_result",
"on_optimization_end")), is.null)

if ("on_result" %in% names(stages)) {
.Deprecated(old = "on_result", new = "on_result_end")
stages$on_result_end = stages$on_result
stages$on_result = NULL
}

walk(stages, function(stage) assert_function(stage, args = c("callback", "context")))
callback = CallbackAsyncTuning$new(id, label, man)
iwalk(stages, function(stage, name) callback[[name]] = stage)
callback
}


#' @title Assertions for Callbacks
#'
#' @description
#' Assertions for [CallbackAsyncTuning] class.
#'
#' @param callback ([CallbackAsyncTuning]).
#' @param null_ok (`logical(1)`)\cr
#' If `TRUE`, `NULL` is allowed.
#'
#' @return [CallbackAsyncTuning | List of [CallbackAsyncTuning]s.
#' @export
assert_async_tuning_callback = function(callback, null_ok = FALSE) {
if (null_ok && is.null(callback)) return(invisible(NULL))
assert_class(callback, "CallbackAsyncTuning")
invisible(callback)
}

#' @export
#' @param callbacks (list of [CallbackAsyncTuning]).
#' @rdname assert_async_tuning_callback
assert_async_tuning_callbacks = function(callbacks) {
invisible(lapply(callbacks, assert_callback))
}
Loading

0 comments on commit ee1f2a4

Please sign in to comment.