From de9b7b7172e2234042277c1bd03795b5a94c533c Mon Sep 17 00:00:00 2001 From: Sebastian Fischer Date: Wed, 23 Apr 2025 10:40:00 +0200 Subject: [PATCH 1/4] fix(cv): number of rows must be greater than folds --- NEWS.md | 2 ++ R/ResamplingCV.R | 6 +++++- R/ResamplingRepeatedCV.R | 5 ++++- tests/testthat/test_Resampling.R | 8 ++++++++ 4 files changed, 19 insertions(+), 2 deletions(-) diff --git a/NEWS.md b/NEWS.md index aa54a18e5..3e3978424 100644 --- a/NEWS.md +++ b/NEWS.md @@ -8,6 +8,8 @@ * feat: `benchmark_grid()` will now throw a warning if you mix different predict types in the design (#1273). * feat: Converting a `BenchmarkResult` to a `data.table` now includes the `task_id`, `learner_id`, and `resampling_id` columns (#1275). +* fix: Instantiating (repeated) CV on tasks with fewer observations than the +number of folds now fails. # mlr3 0.23.0 diff --git a/R/ResamplingCV.R b/R/ResamplingCV.R index 34a0e60fc..f1d12aac5 100644 --- a/R/ResamplingCV.R +++ b/R/ResamplingCV.R @@ -61,9 +61,13 @@ ResamplingCV = R6Class("ResamplingCV", inherit = Resampling, private = list( .sample = function(ids, ...) { + pvs = self$param_set$get_values() + if (length(ids) < pvs$folds) { + stopf("Cannot instantiate ResamplingCV with %i folds on a task with %i rows.", pvs$folds, length(ids)) + } data.table( row_id = ids, - fold = shuffle(seq_along0(ids) %% as.integer(self$param_set$values$folds) + 1L), + fold = shuffle(seq_along0(ids) %% as.integer(pvs$folds) + 1L), key = "fold" ) }, diff --git a/R/ResamplingRepeatedCV.R b/R/ResamplingRepeatedCV.R index bbf3b492c..ae2855422 100644 --- a/R/ResamplingRepeatedCV.R +++ b/R/ResamplingRepeatedCV.R @@ -93,7 +93,10 @@ ResamplingRepeatedCV = R6Class("ResamplingRepeatedCV", inherit = Resampling, private = list( .sample = function(ids, ...) 
{ - pv = self$param_set$values + pv = self$param_set$get_values() + if (length(ids) < pv$folds) { + stopf("Cannot instantiate ResamplingRepeatedCV with %i folds on a task with %i rows.", pv$folds, length(ids)) + } n = length(ids) folds = as.integer(pv$folds) map_dtr(seq_len(pv$repeats), function(i) { diff --git a/tests/testthat/test_Resampling.R b/tests/testthat/test_Resampling.R index db7f67d7b..51943d6e7 100644 --- a/tests/testthat/test_Resampling.R +++ b/tests/testthat/test_Resampling.R @@ -158,3 +158,11 @@ test_that("task_row_hash in Resampling works correctly", { resampling$instantiate(task) expect_identical(resampling$task_row_hash, task$row_hash) }) + +test_that("folds must be <= task size", { + cv = rsmp("cv", folds = 151) + rep_cv = rsmp("repeated_cv", folds = 151) + task = tsk("iris") + expect_error(cv$instantiate(task), "Cannot instantiate ResamplingCV with 151 folds on a task with 150 rows") + expect_error(rep_cv$instantiate(task), "Cannot instantiate ResamplingRepeatedCV with 151 folds on a task with 150 rows") +}) From 2bebb5418467d7e28a712b166b8b165c3facaccb Mon Sep 17 00:00:00 2001 From: Sebastian Fischer Date: Wed, 23 Apr 2025 11:22:23 +0200 Subject: [PATCH 2/4] ... --- R/Resampling.R | 4 ++++ R/ResamplingCV.R | 9 ++++++--- R/ResamplingRepeatedCV.R | 10 +++++++--- 3 files changed, 17 insertions(+), 6 deletions(-) diff --git a/R/Resampling.R b/R/Resampling.R index 6cf54d97d..2986a7b58 100644 --- a/R/Resampling.R +++ b/R/Resampling.R @@ -170,6 +170,7 @@ Resampling = R6Class("Resampling", #' the object in its previous state. 
instantiate = function(task) { task = assert_task(as_task(task)) + private$.check(task) strata = task$strata groups = task$groups @@ -257,6 +258,9 @@ Resampling = R6Class("Resampling", .id = NULL, .hash = NULL, .groups = NULL, + .check = function(task) { + TRUE + }, .get_set = function(getter, i) { if (!self$is_instantiated) { diff --git a/R/ResamplingCV.R b/R/ResamplingCV.R index f1d12aac5..af44ffec9 100644 --- a/R/ResamplingCV.R +++ b/R/ResamplingCV.R @@ -62,15 +62,18 @@ ResamplingCV = R6Class("ResamplingCV", inherit = Resampling, private = list( .sample = function(ids, ...) { pvs = self$param_set$get_values() - if (length(ids) < pvs$folds) { - stopf("Cannot instantiate ResamplingCV with %i folds on a task with %i rows.", pvs$folds, length(ids)) - } data.table( row_id = ids, fold = shuffle(seq_along0(ids) %% as.integer(pvs$folds) + 1L), key = "fold" ) }, + .check = function(task) { + pvs = self$param_set$get_values() + if (task$nrow < pvs$folds) { + stopf("Cannot instantiate ResamplingCV with %i folds on a task with %i rows.", pvs$folds, task$nrow) + } + }, .get_train = function(i) { self$instance[!list(i), "row_id", on = "fold"][[1L]] diff --git a/R/ResamplingRepeatedCV.R b/R/ResamplingRepeatedCV.R index ae2855422..aa84e0c9a 100644 --- a/R/ResamplingRepeatedCV.R +++ b/R/ResamplingRepeatedCV.R @@ -94,15 +94,19 @@ ResamplingRepeatedCV = R6Class("ResamplingRepeatedCV", inherit = Resampling, private = list( .sample = function(ids, ...) 
{ pv = self$param_set$get_values() - if (length(ids) < pv$folds) { - stopf("Cannot instantiate ResamplingRepeatedCV with %i folds on a task with %i rows.", pv$folds, length(ids)) - } n = length(ids) folds = as.integer(pv$folds) map_dtr(seq_len(pv$repeats), function(i) { data.table(row_id = ids, rep = i, fold = shuffle(seq_len0(n) %% folds + 1L)) }) }, + .check = function(task) { + pvs = self$param_set$get_values() + + if (task$nrow < pvs$folds) { + stopf("Cannot instantiate ResamplingRepeatedCV with %i folds on a task with %i rows.", pvs$folds, length(ids)) + } + }, .get_train = function(i) { i = as.integer(i) - 1L From a095960921da2d348b41b052edbea6773fca3156 Mon Sep 17 00:00:00 2001 From: Sebastian Fischer Date: Wed, 23 Apr 2025 11:30:23 +0200 Subject: [PATCH 3/4] ... --- R/ResamplingRepeatedCV.R | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/R/ResamplingRepeatedCV.R b/R/ResamplingRepeatedCV.R index aa84e0c9a..5daff8123 100644 --- a/R/ResamplingRepeatedCV.R +++ b/R/ResamplingRepeatedCV.R @@ -104,7 +104,7 @@ ResamplingRepeatedCV = R6Class("ResamplingRepeatedCV", inherit = Resampling, pvs = self$param_set$get_values() if (task$nrow < pvs$folds) { - stopf("Cannot instantiate ResamplingRepeatedCV with %i folds on a task with %i rows.", pvs$folds, length(ids)) + stopf("Cannot instantiate ResamplingRepeatedCV with %i folds on a task with %i rows.", pvs$folds, task$nrow) } }, From 5cee49c74e9ba4a7e4b6208d7ae5834be58d7f92 Mon Sep 17 00:00:00 2001 From: Sebastian Fischer Date: Thu, 24 Apr 2025 13:39:12 +0200 Subject: [PATCH 4/4] address grouping --- R/ResamplingCV.R | 6 ++++++ R/ResamplingRepeatedCV.R | 7 ++++++- tests/testthat/test_Resampling.R | 6 ++++++ 3 files changed, 18 insertions(+), 1 deletion(-) diff --git a/R/ResamplingCV.R b/R/ResamplingCV.R index af44ffec9..9922b492d 100644 --- a/R/ResamplingCV.R +++ b/R/ResamplingCV.R @@ -70,6 +70,12 @@ ResamplingCV = R6Class("ResamplingCV", inherit = Resampling, }, .check = function(task) { pvs = 
self$param_set$get_values() + if (!is.null(task$groups)) { + n_groups = length(unique(task$groups$group)) + if (n_groups < pvs$folds) { + stopf("Cannot instantiate ResamplingCV with %i folds on a grouped task with %i groups.", pvs$folds, n_groups) + } + } if (task$nrow < pvs$folds) { stopf("Cannot instantiate ResamplingCV with %i folds on a task with %i rows.", pvs$folds, task$nrow) } diff --git a/R/ResamplingRepeatedCV.R b/R/ResamplingRepeatedCV.R index 5daff8123..d27ee43e4 100644 --- a/R/ResamplingRepeatedCV.R +++ b/R/ResamplingRepeatedCV.R @@ -102,7 +102,12 @@ ResamplingRepeatedCV = R6Class("ResamplingRepeatedCV", inherit = Resampling, }, .check = function(task) { pvs = self$param_set$get_values() - + if (!is.null(task$groups)) { + n_groups = length(unique(task$groups$group)) + if (n_groups < pvs$folds) { + stopf("Cannot instantiate ResamplingRepeatedCV with %i folds on a grouped task with %i groups.", pvs$folds, n_groups) + } + } if (task$nrow < pvs$folds) { stopf("Cannot instantiate ResamplingRepeatedCV with %i folds on a task with %i rows.", pvs$folds, task$nrow) } diff --git a/tests/testthat/test_Resampling.R b/tests/testthat/test_Resampling.R index 51943d6e7..38076622f 100644 --- a/tests/testthat/test_Resampling.R +++ b/tests/testthat/test_Resampling.R @@ -165,4 +165,10 @@ test_that("folds must be <= task size", { task = tsk("iris") expect_error(cv$instantiate(task), "Cannot instantiate ResamplingCV with 151 folds on a task with 150 rows") expect_error(rep_cv$instantiate(task), "Cannot instantiate ResamplingRepeatedCV with 151 folds on a task with 150 rows") + + task$col_roles$group = "Species" + cv$param_set$set_values(folds = 4L) + rep_cv$param_set$set_values(folds = 4L) + expect_error(cv$instantiate(task), "on a grouped task with 3 groups") + expect_error(rep_cv$instantiate(task), "on a grouped task with 3 groups") })