|
46 | 46 | #' Next, the grouping information is replaced with the respective row ids to generate training and test sets. |
47 | 47 | #' The sets can be accessed via `$train_set(i)` and `$test_set(i)`, respectively. |
48 | 48 | #' |
| 49 | +#' @section Inheriting: |
| 50 | +#' It is possible to overwrite both `private$.get_instance()` to have full control, or only `private$.sample()` when one wants to use the pre-defined mechanism for stratification and grouping. |
49 | 51 | #' |
50 | 52 | #' @template seealso_resampling |
51 | 53 | #' @export |
@@ -173,25 +175,8 @@ Resampling = R6Class("Resampling", |
173 | 175 | #' the object in its previous state. |
174 | 176 | instantiate = function(task) { |
175 | 177 | task = assert_task(as_task(task)) |
176 | | - strata = task$strata |
177 | | - groups = task$groups |
178 | | - |
179 | | - if (is.null(strata)) { |
180 | | - if (is.null(groups)) { |
181 | | - instance = private$.sample(task$row_ids, task = task) |
182 | | - } else { |
183 | | - private$.groups = groups |
184 | | - instance = private$.sample(unique(groups$group), task = task) |
185 | | - } |
186 | | - } else { |
187 | | - if (!is.null(groups)) { |
188 | | - stopf("Cannot combine stratification with grouping") |
189 | | - } |
190 | | - instance = private$.combine(lapply(strata$row_id, private$.sample, task = task)) |
191 | | - } |
192 | | - |
193 | 178 | private$.hash = NULL |
194 | | - self$instance = instance |
| 179 | + self$instance = private$.get_instance(task) |
195 | 180 | self$task_hash = task$hash |
196 | 181 | self$task_row_hash = task$row_hash |
197 | 182 | self$task_nrow = task$nrow |
@@ -261,6 +246,24 @@ Resampling = R6Class("Resampling", |
261 | 246 | .hash = NULL, |
262 | 247 | .groups = NULL, |
263 | 248 |
|
| 249 | + .get_instance = function(task) { |
| 250 | + strata = task$strata |
| 251 | + groups = task$groups |
| 252 | + if (is.null(strata)) { |
| 253 | + if (is.null(groups)) { |
| 254 | + private$.sample(task$row_ids, task = task) |
| 255 | + } else { |
| 256 | + private$.groups = groups |
| 257 | + private$.sample(unique(groups$group), task = task) |
| 258 | + } |
| 259 | + } else { |
| 260 | + if (!is.null(groups)) { |
| 261 | + stopf("Cannot combine stratification with grouping") |
| 262 | + } |
| 263 | + private$.combine(lapply(strata$row_id, private$.sample, task = task)) |
| 264 | + } |
| 265 | + }, |
| 266 | + |
264 | 267 | .get_set = function(getter, i) { |
265 | 268 | if (!self$is_instantiated) { |
266 | 269 | stopf("Resampling '%s' has not been instantiated yet", self$id) |
|
0 commit comments