feat: add learner, resampling and measure weights #1124

Draft: wants to merge 15 commits into base: main
2 changes: 1 addition & 1 deletion DESCRIPTION
@@ -44,7 +44,7 @@ Depends:
R (>= 3.1.0)
Imports:
R6 (>= 2.4.1),
backports,
backports (>= 1.5.0),
checkmate (>= 2.0.0),
data.table (>= 1.15.0),
evaluate,
7 changes: 7 additions & 0 deletions NEWS.md
@@ -1,5 +1,12 @@
# mlr3 (development version)

* BREAKING CHANGE: The `weights` property and its functionality are split into `weights_learner` and `weights_measure`:

* `weights_learner`: Weights used during training by the Learner.
  * `weights_measure`: Weights used when scoring predictions via measures.

  Each of these can be disabled via the new `use_weights` hyperparameter (`Measure`, `Resampling`) or field (`Learner`); see the sketch below.
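
  A minimal sketch of assigning the new column role (the role assignment shown is illustrative of the dev-version API):

      library(mlr3)
      task = tsk("mtcars")
      task$set_col_roles("wt", roles = "weights_learner")
      lrn("regr.rpart")$use_weights  # "use", since rpart supports weights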

# mlr3 0.22.1

* fix: Extend `assert_measure()` with checks for trained models in `assert_scorable()`.
60 changes: 58 additions & 2 deletions R/Learner.R
@@ -67,6 +67,19 @@
#' Only available for [`Learner`]s with the `"internal_tuning"` property.
#' If the learner is not trained yet, this returns `NULL`.
#'
#' @section Weights:
#'
#' Many learners support observation weights, indicated by their property `"weights"`.
#' The weights are stored in the [Task] where the column role `weights_learner` needs to be assigned to a single numeric column.
#' If a task has weights and the learner supports them, they are used automatically.
#' If a task has weights but the learner does not support them, an error is thrown.
#' Both of these behaviors can be disabled by setting the `use_weights` field to `"ignore"`.
#' See the description of `use_weights` for more information.
#'
#' If the learner is set up to use weights but the task has no designated weights column, an unweighted model is fitted instead.
#' When weights are used, they are passed down to the learner directly.
#' They do not need to sum to 1.
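#'
#' A hypothetical usage sketch (the role assignment is illustrative):
#' ```
#' task = tsk("mtcars")
#' task$set_col_roles("wt", roles = "weights_learner")
#' learner = lrn("regr.rpart")  # has the "weights" property
#' learner$train(task)          # "wt" is passed on as observation weights
#' ```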
#'
#' @section Setting Hyperparameters:
#'
#' All information about hyperparameters is stored in the slot `param_set` which is a [paradox::ParamSet].
@@ -212,7 +225,6 @@ Learner = R6Class("Learner",
self$id = assert_string(id, min.chars = 1L)
self$label = assert_string(label, na.ok = TRUE)
self$task_type = assert_choice(task_type, mlr_reflections$task_types$type)
private$.param_set = assert_param_set(param_set)
self$feature_types = assert_ordered_set(feature_types, mlr_reflections$task_feature_types, .var.name = "feature_types")
self$predict_types = assert_ordered_set(predict_types, names(mlr_reflections$learner_predict_types[[task_type]]),
empty.ok = FALSE, .var.name = "predict_types")
@@ -222,6 +234,13 @@
self$packages = union("mlr3", assert_character(packages, any.missing = FALSE, min.chars = 1L))
self$man = assert_string(man, na.ok = TRUE)

if ("weights" %in% self$properties) {
self$use_weights = "use"
} else {
self$use_weights = "error"
}
private$.param_set = param_set

check_packages_installed(packages, msg = sprintf("Package '%%s' required but not installed for Learner '%s'", id))
},

@@ -402,7 +421,7 @@ Learner = R6Class("Learner",
assert_names(newdata$colnames, must.include = task$feature_names)

# the following columns are automatically set to NA if missing
impute = unlist(task$col_roles[c("target", "name", "order", "stratum", "group", "weight")], use.names = FALSE)
impute = unlist(task$col_roles[c("target", "name", "order", "stratum", "group", "weights_learner", "weights_measure")], use.names = FALSE)
impute = setdiff(impute, newdata$colnames)
if (length(impute)) {
# create list with correct NA types and cbind it to the backend
@@ -509,6 +528,26 @@
),

active = list(
#' @field use_weights (`character(1)`)\cr
#' How to use weights.
#' Settings are `"use"`, `"ignore"`, and `"error"`.
#'
#' * `"use"`: use weights, as supported by the underlying `Learner`.
#' * `"ignore"`: do not use weights.
#' * `"error"`: throw an error if weights are present in the training `Task`.
#'
#' For `Learner`s with the property `"weights"`, this is initialized as `"use"`.
#' For `Learner`s that do not support weights, i.e. without the `"weights"` property, this is initialized as `"error"`.
#' The latter behavior is to avoid cases where a user erroneously assumes that a `Learner` supports weights when it does not.
#' For `Learner`s that do not support weights, set `use_weights` to `"ignore"` to handle tasks with weights anyway (the weights are dropped).
use_weights = function(rhs) {
if (!missing(rhs)) {
assert_choice(rhs, c(if ("weights" %in% self$properties) "use", "ignore", "error"))
private$.use_weights = rhs
}
private$.use_weights
},
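
    # Sketch of opting out for a learner without the "weights" property
    # (whether a given learner has the property is illustrative here):
    #   learner$use_weights = "ignore"  # accept weighted tasks, dropping weights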

#' @field data_formats (`character()`)\cr
#' Supported data formats. Always `"data.table"`.
#' This is deprecated and will be removed in the future.
@@ -632,12 +671,29 @@
),

private = list(
.use_weights = NULL,
.encapsulation = c(train = "none", predict = "none"),
.fallback = NULL,
.predict_type = NULL,
.param_set = NULL,
.hotstart_stack = NULL,

# Retrieve weights from a task if it has weights and the user did not
# deactivate weight usage through `self$use_weights`.
# - `task`: Task to retrieve weights from.
# - `no_weights_val`: Value to return if no weights are found (default: NULL).
# Returns a numeric vector of weights, or `no_weights_val`.
.get_weights = function(task, no_weights_val = NULL) {
if ("weights" %nin% self$properties) {
stop("private$.get_weights should not be used in Learners that do not have the 'weights' property.")
}
if (self$use_weights == "use" && "weights_learner" %in% task$properties) {
task$weights_learner$weight
} else {
no_weights_val
}
},
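
    # Typical use inside a Learner's .train(), cf. LearnerClassifRpart:
    #   pv = self$param_set$get_values(tags = "train")
    #   pv$weights = private$.get_weights(task)  # NULL when absent or ignored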

deep_clone = function(name, value) {
switch(name,
.param_set = value$clone(deep = TRUE),
8 changes: 2 additions & 6 deletions R/LearnerClassifRpart.R
@@ -35,9 +35,8 @@ LearnerClassifRpart = R6Class("LearnerClassifRpart", inherit = LearnerClassif,
minsplit = p_int(1L, default = 20L, tags = "train"),
surrogatestyle = p_int(0L, 1L, default = 0L, tags = "train"),
usesurrogate = p_int(0L, 2L, default = 2L, tags = "train"),
xval = p_int(0L, default = 10L, tags = "train")
xval = p_int(0L, default = 10L, init = 0L, tags = "train")
)
ps$values = list(xval = 0L)

super$initialize(
id = "classif.rpart",
@@ -77,10 +76,7 @@ LearnerClassifRpart = R6Class("LearnerClassifRpart", inherit = LearnerClassif,
.train = function(task) {
pv = self$param_set$get_values(tags = "train")
names(pv) = replace(names(pv), names(pv) == "keep_model", "model")
if ("weights" %in% task$properties) {
pv = insert_named(pv, list(weights = task$weights$weight))
}

pv$weights = private$.get_weights(task)
invoke(rpart::rpart, formula = task$formula(), data = task$data(), .args = pv, .opts = allow_partial_matching)
},

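A training sketch with an explicit weight column (the column `w` is hypothetical; `cbind()` matching rows by position is assumed):

    library(mlr3)
    library(data.table)
    task = tsk("sonar")
    task$cbind(data.table(w = runif(task$nrow)))
    task$set_col_roles("w", roles = "weights_learner")
    lrn("classif.rpart")$train(task)  # "w" is passed to rpart::rpart(weights = ...)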
8 changes: 2 additions & 6 deletions R/LearnerRegrRpart.R
@@ -35,9 +35,8 @@ LearnerRegrRpart = R6Class("LearnerRegrRpart", inherit = LearnerRegr,
minsplit = p_int(1L, default = 20L, tags = "train"),
surrogatestyle = p_int(0L, 1L, default = 0L, tags = "train"),
usesurrogate = p_int(0L, 2L, default = 2L, tags = "train"),
xval = p_int(0L, default = 10L, tags = "train")
xval = p_int(0L, default = 10L, init = 0L, tags = "train")
)
ps$values = list(xval = 0L)

super$initialize(
id = "regr.rpart",
@@ -77,10 +76,7 @@ LearnerRegrRpart = R6Class("LearnerRegrRpart", inherit = LearnerRegr,
.train = function(task) {
pv = self$param_set$get_values(tags = "train")
names(pv) = replace(names(pv), names(pv) == "keep_model", "model")
if ("weights" %in% task$properties) {
pv = insert_named(pv, list(weights = task$weights$weight))
}

pv$weights = private$.get_weights(task)
invoke(rpart::rpart, formula = task$formula(), data = task$data(), .args = pv, .opts = allow_partial_matching)
},

101 changes: 86 additions & 15 deletions R/Measure.R
@@ -24,6 +24,15 @@
#' In such cases it is necessary to overwrite the public methods `$aggregate()` and/or `$score()` to return a named `numeric()`
#' where at least one of its names corresponds to the `id` of the measure itself.
#'
#' @section Weights:
#'
#' Many measures support observation weights, indicated by their property `"weights"`.
#' The weights are stored in the [Task] where the column role `weights_measure` needs to be assigned to a single numeric column.
#' The weights are used automatically if found in the task; this can be disabled by setting the hyperparameter `use_weights` to `"ignore"`.
#' If the measure is set up to use weights but the task does not have a designated `weights_measure` column, an unweighted score is calculated instead.
#' The weights do not need to sum to 1; they are normalized by the measure.
#' See the description of `use_weights` for more information.
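#'
#' A hypothetical scoring sketch (measure id and role assignment are illustrative):
#' ```
#' task = tsk("mtcars")
#' task$set_col_roles("wt", roles = "weights_measure")
#' learner = lrn("regr.rpart")$train(task)
#' learner$predict(task)$score(msr("regr.mse"), task = task)  # weighted score
#' ```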
#'
#' @template param_id
#' @template param_param_set
#' @template param_range
@@ -94,10 +103,6 @@ Measure = R6Class("Measure",
#' Lower and upper bound of possible performance scores.
range = NULL,

#' @field properties (`character()`)\cr
#' Properties of this measure.
properties = NULL,

#' @field minimize (`logical(1)`)\cr
#' If `TRUE`, good predictions correspond to small values of performance scores.
minimize = NULL,
@@ -117,7 +122,6 @@
predict_sets = "test", task_properties = character(), packages = character(),
label = NA_character_, man = NA_character_, trafo = NULL) {

self$properties = unique(properties)
self$id = assert_string(id, min.chars = 1L)
self$label = assert_string(label, na.ok = TRUE)
self$task_type = task_type
@@ -136,10 +140,20 @@
assert_choice(task_type, mlr_reflections$task_types$type)
assert_subset(properties, mlr_reflections$measure_properties[[task_type]])
assert_choice(predict_type, names(mlr_reflections$learner_predict_types[[task_type]]))
assert_subset(properties, mlr_reflections$measure_properties[[task_type]])
assert_subset(task_properties, mlr_reflections$task_properties[[task_type]])
} else {
assert_subset(properties, unique(unlist(mlr_reflections$measure_properties, use.names = FALSE)))
}

if ("weights" %in% properties) {
self$use_weights = "use"
} else if ("requires_no_prediction" %in% properties) {
self$use_weights = "ignore"
} else {
self$use_weights = "error"
}

self$properties = unique(properties)
self$predict_type = predict_type
self$predict_sets = predict_sets
self$task_properties = task_properties
@@ -168,6 +182,7 @@
catn(str_indent("* Parameters:", as_short_string(self$param_set$values, 1000L)))
catn(str_indent("* Properties:", self$properties))
catn(str_indent("* Predict type:", self$predict_type))
catn(str_indent("* Aggregator:", if (is.null(self$aggregator)) "mean()" else "[user-defined]"))
},

#' @description
@@ -195,24 +210,25 @@
#' @return `numeric(1)`.
score = function(prediction, task = NULL, learner = NULL, train_set = NULL) {
assert_scorable(self, task = task, learner = learner, prediction = prediction)
assert_prediction(prediction, null.ok = "requires_no_prediction" %nin% self$properties)
properties = self$properties
assert_prediction(prediction, null.ok = "requires_no_prediction" %nin% properties)

# NOTE: these checks should be moved into assert_measure(),
# except where they are superfluous for rr$score() and bmr$score();
# in that case they should be added below instead
if ("requires_task" %in% self$properties && is.null(task)) {
if ("requires_task" %in% properties && is.null(task)) {
stopf("Measure '%s' requires a task", self$id)
}

if ("requires_learner" %in% self$properties && is.null(learner)) {
if ("requires_learner" %in% properties && is.null(learner)) {
stopf("Measure '%s' requires a learner", self$id)
}

if (!is_scalar_na(self$task_type) && self$task_type != prediction$task_type) {
stopf("Measure '%s' incompatible with task type '%s'", self$id, prediction$task_type)
}

if ("requires_train_set" %in% self$properties && is.null(train_set)) {
if ("requires_train_set" %in% properties && is.null(train_set)) {
stopf("Measure '%s' requires the train_set", self$id)
}

@@ -227,8 +243,14 @@
#'
#' @return `numeric(1)`.
aggregate = function(rr) {

switch(self$average,
"macro_weighted" = {
aggregator = self$aggregator %??% weighted.mean
tab = score_measures(rr, list(self), reassemble = FALSE, view = get_private(rr)$.view,
iters = get_private(rr$resampling)$.primary_iters)
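# fold scores are averaged with weights equal to each iteration's summed
# observation weights (the ".weights" column attached by score_measures())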
weights = tab$.weights
set_names(aggregator(tab[[self$id]], weights), self$id)
},
"macro" = {
aggregator = self$aggregator %??% mean
tab = score_measures(rr, list(self), reassemble = FALSE, view = get_private(rr)$.view,
Expand All @@ -245,7 +267,7 @@ Measure = R6Class("Measure",
},
"custom" = {
if (!is.null(get_private(rr$resampling)$.primary_iters) && "primary_iters" %nin% self$properties &&
!test_permutation(get_private(rr$resampling)$.primary_iters, seq_len(rr$resampling$iters))) {
!test_permutation(get_private(rr$resampling)$.primary_iters, seq_len(rr$resampling$iters))) {
stopf("Resample result has non-NULL primary_iters, but measure '%s' cannot handle them", self$id)
}
private$.aggregator(rr)
@@ -274,6 +296,17 @@
self$predict_sets, mget(private$.extra_hash, envir = self))
},

#' @field properties (`character()`)\cr
#' Properties of this measure.
properties = function(rhs) {
if (!missing(rhs)) {
props = if (is.na(self$task_type)) unique(unlist(mlr_reflections$measure_properties, use.names = FALSE)) else mlr_reflections$measure_properties[[self$task_type]]
private$.properties = assert_subset(rhs, props)
} else {
private$.properties
}
},

#' @field average (`character(1)`)\cr
#' Method for aggregation:
#'
@@ -288,7 +321,7 @@
#' The measure comes with a custom aggregation method which directly operates on a [ResampleResult].
average = function(rhs) {
if (!missing(rhs)) {
private$.average = assert_choice(rhs, c("micro", "macro", "custom"))
private$.average = assert_choice(rhs, c("micro", "macro", "custom", "macro_weighted"))
} else {
private$.average
}
@@ -302,14 +335,40 @@
} else {
private$.aggregator
}
},

#' @field use_weights (`character(1)`)\cr
#' How to handle weights.
#' Settings are `"use"`, `"ignore"`, and `"error"`.
#'
#' * `"use"`: Weights are used automatically if found in the task, as supported by the measure.
#' * `"ignore"`: Weights are ignored.
#' * `"error"`: An error is thrown if weights are present in the `Task` being scored.
#'
#' For measures with the property `"weights"`, this is initialized as `"use"`.
#' For measures with the property `"requires_no_prediction"`, this is initialized as `"ignore"`.
#' For measures that have neither of the properties, this is initialized as `"error"`.
#' The latter behavior is to avoid cases where a user erroneously assumes that a measure supports weights when it does not.
#' For measures that do not support weights, set `use_weights` to `"ignore"` to handle tasks with weights anyway (the weights are dropped).
use_weights = function(rhs) {
if (!missing(rhs)) {
private$.use_weights = assert_choice(rhs, c("use", "ignore", "error"))
} else {
private$.use_weights
}
}
),

private = list(
.properties = character(),
.predict_sets = NULL,
.extra_hash = character(),
.average = NULL,
.aggregator = NULL
.aggregator = NULL,
.use_weights = NULL,
.score = function(prediction, task, weights, ...) {
stop("abstract method")
}
)
)

@@ -364,7 +423,8 @@ score_single_measure = function(measure, task, learner, train_set, prediction) {
return(NaN)
}

get_private(measure)$.score(prediction = prediction, task = task, learner = learner, train_set = train_set)
get_private(measure)$.score(prediction = prediction, task = task, learner = learner, train_set = train_set,
weights = if (measure$use_weights == "use" && "weights_measure" %in% task$properties) task$weights_measure[list(prediction$row_ids), "weight"][[1L]])
}

#' @title Workhorse function to calculate multiple scores
@@ -387,6 +447,17 @@ score_measures = function(obj, measures, reassemble = TRUE, view = NULL, iters =
reassemble_learners = reassemble ||
some(measures, function(m) any(c("requires_learner", "requires_model") %in% m$properties))
tab = get_private(obj)$.data$as_data_table(view = view, reassemble_learners = reassemble_learners, convert_predictions = FALSE)
if ("weights_measure" %in% tab$task[[1L]]$properties) {
weightsumgetter = function(task, prediction) {
sum(task$weights_measure[list(prediction$row_ids), "weight"][[1L]])
}
} else {
# no weights recorded, use unit weights
weightsumgetter = function(task, prediction) {
as.numeric(length(prediction$row_ids)) # should explicitly be a numeric, not an integer
}
}
set(tab, j = ".weights", value = pmap_dbl(tab[, c("task", "prediction"), with = FALSE], weightsumgetter))

if (!is.null(iters)) {
tab = tab[list(iters), on = "iteration"]
Expand Down