
feat: add mirai parallelization #1314


Open
wants to merge 23 commits into base: main
2 changes: 2 additions & 0 deletions .Rbuildignore
@@ -23,4 +23,6 @@
^cran-comments\.md$
^CRAN-SUBMISSION$
^benchmark$
^attic$
^.cursor$

1 change: 1 addition & 0 deletions .gitignore
@@ -183,3 +183,4 @@ revdep/
# misc
Meta/
Rplots.pdf
.cursor/
3 changes: 3 additions & 0 deletions DESCRIPTION
@@ -67,12 +67,15 @@ Suggests:
codetools,
datasets,
future.callr,
mirai (>= 2.3.0),
mlr3data,
progressr,
remotes,
RhpcBLASctl,
rpart,
testthat (>= 3.2.0)
Remotes:
mlr-org/mlr3misc
Encoding: UTF-8
Config/testthat/edition: 3
Config/testthat/parallel: false
7 changes: 6 additions & 1 deletion R/Learner.R
@@ -536,6 +536,11 @@ Learner = R6Class("Learner",
#' * `"callr"`: Uses the package \CRANpkg{callr} to call the learner, measure time and do the logging.
#' This encapsulation spawns a separate R session in which the learner is called.
#' While this comes with a considerable overhead, it also guards your session from being torn down by segfaults.
#' * `"mirai"`: Uses the package \CRANpkg{mirai} to call the learner, measure time and do the logging.
#' This encapsulation calls the function in a `mirai` on a `daemon`.
#' The `daemon` can be pre-started via `daemons(1)`; otherwise, a new R session is created for each encapsulated call.
#' If a `daemon` is already running, it is used to execute all calls.
#' Using `"mirai"` is as safe as `"callr"` but much faster when several learners are encapsulated one after the other on the same daemon.
#'
#' The fallback learner is fitted to create valid predictions in case that either the model fitting or the prediction of the original learner fails.
#' If the training step or the predict step of the original learner fails, the fallback is used to make the predictions.
@@ -554,7 +559,7 @@
#'
#' @return `self` (invisibly).
encapsulate = function(method, fallback = NULL) {
assert_choice(method, c("none", "try", "evaluate", "callr"))
assert_choice(method, c("none", "try", "evaluate", "callr", "mirai"))

if (method != "none") {
assert_learner(fallback, task_type = self$task_type)
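
A minimal sketch of how the new `"mirai"` encapsulation documented above might be used (the learner, fallback, and task are placeholders chosen for illustration, not taken from this PR); pre-starting a daemon avoids spawning a fresh R session per encapsulated call:

library(mlr3)
mirai::daemons(1)                                  # pre-start one daemon

learner = lrn("classif.rpart")
learner$encapsulate("mirai", fallback = lrn("classif.featureless"))

task = tsk("pima")
learner$train(task)                                # training runs inside the daemon
learner$predict(task)                              # so does the prediction

mirai::daemons(0)                                  # shut the daemon down again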
3 changes: 3 additions & 0 deletions R/helper_exec.R
@@ -23,6 +23,9 @@ future_map = function(n, FUN, ..., MoreArgs = list()) {
if (getOption("mlr3.debug", FALSE)) {
lg$info("Running experiments sequentially in debug mode with %i iterations", n)
mapply(FUN, ..., MoreArgs = MoreArgs, SIMPLIFY = FALSE, USE.NAMES = FALSE)
} else if (requireNamespace("mirai", quietly = TRUE) && mirai::daemons_set()) {
lg$debug("Running resample() via mirai with %i iterations", n)
mirai::collect_mirai(mirai::mirai_map(data.table(...), FUN, .args = c(MoreArgs, list(is_sequential = FALSE))))
Collaborator

why does it call data.table(...)? And does it have any chance of getting into conflict with any of the special args of data.table() (e.g. key, keep.rownames)?

Member Author

@be-marc · May 22, 2025

mirai::mirai_map() only accepts a data.frame or matrix when you want to map over multiple inputs. data.table handles list columns better. ... contains the iteration and learner when called by resample() and additionally the task and resampling when called by benchmark().

And does it have any chance of getting into conflict with any of the special args

I think no. future_map is only called internally.

} else {
is_sequential = inherits(plan(), "sequential")
scheduling = if (!is_sequential && isTRUE(getOption("mlr3.exec_random", TRUE))) structure(TRUE, ordering = "random") else TRUE
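
For readers unfamiliar with the pattern discussed in the comments above, a standalone sketch of mirai_map()'s multiple map over the rows of a data.table, as used in this hunk (the columns `iteration` and `value` are made up for illustration): each row becomes one call on a daemon, and collect_mirai() waits for all results.

library(data.table)
mirai::daemons(2)

grid = data.table(iteration = 1:4, value = c(10, 20, 30, 40))

# each row's columns are passed as arguments to the mapped function
res = mirai::collect_mirai(
  mirai::mirai_map(grid, function(iteration, value) value * iteration)
)

mirai::daemons(0)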
8 changes: 8 additions & 0 deletions inst/testthat/helper_misc.R
@@ -17,6 +17,14 @@ with_future = function(backend, expr, ...) {
force(expr)
}

with_mirai = function(expr) {
requireNamespace("mirai")
mirai::daemons(1)
on.exit(mirai::daemons(0), add = TRUE)
force(expr)
expect_true(mirai::status()$mirai["completed"] > 0)
}

private = function(x) {
x[[".__enclos_env__"]][["private"]]
}
9 changes: 6 additions & 3 deletions man-roxygen/section_parallelization.R
@@ -1,8 +1,11 @@
#' @section Parallelization:
#'
#' This function can be parallelized with the \CRANpkg{future} package.
#' One job is one resampling iteration, and all jobs are send to an apply function
#' from \CRANpkg{future.apply} in a single batch.
#' This function can be parallelized with the \CRANpkg{future} or \CRANpkg{mirai} package.
#' One job is one resampling iteration.
#' All jobs are sent to an apply function from \CRANpkg{future.apply} or `mirai::mirai_map()` in a single batch.
#' To select a parallel backend, use [future::plan()].
#' To use `mirai`, call `mirai::daemons()` before calling this function.
#' The `future` package guarantees reproducible results independent of the parallel backend.
#' The results obtained with `mirai` will not match those of `future`, but they can be made reproducible by setting a `seed` and `dispatcher = FALSE` when calling `mirai::daemons()`.
#' More on parallelization can be found in the book:
#' \url{https://mlr3book.mlr-org.com/chapters/chapter10/advanced_technical_aspects_of_mlr3.html}
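
A minimal sketch of what the mirai route described in this section could look like in user code (the daemon count, seed, and task/learner choices are arbitrary); per the note above, the seed is combined with dispatcher = FALSE to make results reproducible:

library(mlr3)
mirai::daemons(4, dispatcher = FALSE, seed = 42)   # reproducible mirai results

task = tsk("pima")
learner = lrn("classif.rpart")
rr = resample(task, learner, rsmp("cv", folds = 3))
rr$aggregate(msr("classif.ce"))

mirai::daemons(0)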
5 changes: 5 additions & 0 deletions man/Learner.Rd


9 changes: 6 additions & 3 deletions man/benchmark.Rd


9 changes: 6 additions & 3 deletions man/resample.Rd


File renamed without changes.
56 changes: 56 additions & 0 deletions tests/testthat/test_parallel_mirai.R
@@ -0,0 +1,56 @@
skip_if_not_installed("mirai")

test_that("parallel resample", {
with_mirai({
task = tsk("pima")
learner = lrn("classif.rpart")
rr = resample(task, learner, rsmp("cv", folds = 3))
expect_resample_result(rr)
expect_data_table(rr$errors, nrows = 0L)
})
})

test_that("parallel benchmark", {
task = tsk("pima")
learner = lrn("classif.rpart")

with_mirai({
bmr = benchmark(benchmark_grid(task, learner, rsmp("cv", folds = 3)))
})
expect_benchmark_result(bmr)
expect_equal(bmr$aggregate(conditions = TRUE)$warnings, 0L)
expect_equal(bmr$aggregate(conditions = TRUE)$errors, 0L)
})

test_that("real parallel resample", {
with_mirai({
task = tsk("pima")
learner = lrn("classif.rpart")
rr = resample(task, learner, rsmp("cv", folds = 3))

expect_resample_result(rr)
expect_data_table(rr$errors, nrows = 0L)
})
})

test_that("data table threads are not changed in main session", {
skip_on_os("mac") # number of threads cannot be changed on mac
skip_on_cran()

old_dt_threads = getDTthreads()
on.exit({
setDTthreads(old_dt_threads)
}, add = TRUE)
setDTthreads(2L)

task = tsk("sonar")
learner = lrn("classif.debug", predict_type = "prob")
resampling = rsmp("cv", folds = 3L)
measure = msr("classif.auc")

rr1 = with_seed(123, resample(task, learner, resampling))
expect_equal(getDTthreads(), 2L)

rr2 = with_seed(123, with_mirai(resample(task, learner, resampling)))
expect_equal(getDTthreads(), 2L)
})