
feat: add learner, resampling and measure weights #1124


Merged · 43 commits · May 12, 2025

Changes from all commits
30639e7
refactor: weights (#1065)
mllg Aug 23, 2024
eef7717
fix: default fallback learner
be-marc Aug 23, 2024
6b73434
MeasureRegrRSQ: add comment line for doccing
berndbischl Aug 24, 2024
da4d632
... (#1118)
berndbischl Aug 24, 2024
78e102a
progress
mb706 Aug 24, 2024
37baa67
...
berndbischl Aug 28, 2024
c32bc21
...
berndbischl Aug 28, 2024
e70f028
...
berndbischl Aug 28, 2024
89202f5
...
berndbischl Aug 28, 2024
7bb73c1
Merge branch 'master' into weights_ii
mb706 Dec 19, 2024
bcd1f73
fix oversight in Task.R
mb706 Dec 19, 2024
7f11eeb
progress
mb706 Dec 19, 2024
1ed098c
progress
mb706 Dec 20, 2024
7d24106
no resampling weights
mb706 Dec 20, 2024
404f659
no more weights_resampling
mb706 Dec 20, 2024
d5ba317
progress
mb706 Apr 15, 2025
2823803
Merge branch 'master' into weights_ii
mb706 Apr 15, 2025
59a84ad
fix doc issue
mb706 Apr 15, 2025
6252614
document
mb706 Apr 15, 2025
f1e194d
going to start implementing measure weights properly
mb706 Apr 16, 2025
62667f4
progress
mb706 Apr 16, 2025
de0f8e8
undo bug
mb706 Apr 16, 2025
3998277
autotests
mb706 Apr 16, 2025
0711e4c
document
mb706 Apr 16, 2025
103f472
tests
mb706 Apr 16, 2025
85c043f
finishing tests
mb706 Apr 16, 2025
473a2a6
revert a few changes since we are not doing resampling weights currently
mb706 Apr 16, 2025
327bd53
some cleanup
mb706 Apr 16, 2025
63594d8
make tests pass
mb706 Apr 16, 2025
a16fa00
some more cleaning up
mb706 Apr 16, 2025
0a1828c
some more cleanup
mb706 Apr 16, 2025
791f19a
that is hopefully all
mb706 Apr 16, 2025
a046919
remove some measures superseded by MeasureRegrRSQ
mb706 Apr 17, 2025
51cfaaa
remove confusion_weighted
mb706 Apr 17, 2025
522c180
Adjust some measures and document
mb706 Apr 17, 2025
e91b8d9
adjust tests
mb706 Apr 17, 2025
3e45c59
learner printer extended
mb706 Apr 17, 2025
9566b42
news entry
mb706 Apr 17, 2025
ade2c8d
docs changes
mb706 Apr 22, 2025
c9ac4b5
document()
mb706 Apr 22, 2025
5c15980
fix autotest
be-marc May 12, 2025
61e87e0
Merge branch 'weights_ii' of github.com:mlr-org/mlr3 into weights_ii
be-marc May 12, 2025
ee1e9fb
Merge branch 'main' into weights_ii
be-marc May 12, 2025
2 changes: 1 addition & 1 deletion DESCRIPTION
@@ -46,7 +46,7 @@ Depends:
R (>= 3.3.0)
Imports:
R6 (>= 2.4.1),
backports,
backports (>= 1.5.0),
checkmate (>= 2.0.0),
data.table (>= 1.15.0),
evaluate,
1 change: 1 addition & 0 deletions NAMESPACE
@@ -296,6 +296,7 @@ importFrom(stats,rnorm)
importFrom(stats,runif)
importFrom(stats,sd)
importFrom(stats,terms)
importFrom(stats,weighted.mean)
importFrom(utils,bibentry)
importFrom(utils,data)
importFrom(utils,getFromNamespace)
12 changes: 12 additions & 0 deletions NEWS.md
@@ -1,5 +1,17 @@
# mlr3 (development version)

* BREAKING CHANGE: The `weights` property and functionality are split into `weights_learner` and `weights_measure`:

* `weights_learner`: Weights used during training by the Learner.
* `weights_measure`: Weights used during scoring predictions via measures.

Each of these can be disabled via the new field `use_weights` in `Learner` and `Measure` objects.
* feat: Add `$confusion_weighted` field to `PredictionClassif`.
* feat: Add `$weights` field to `Prediction`. It contains the `weights_measure` weights from the `Task` that was used for prediction.
* feat: Add `"macro_weighted"` option to `Measure$average` field.
* feat: `MeasureRegrRSQ` and `MeasureClassifCost` gain `"weights"` property.
* feat: `LearnerClassifFeatureless`, `LearnerRegrFeatureless`, `LearnerClassifDebug`, `LearnerRegrDebug` gain `"weights"` property.
* feat: `Learner` printer now prints information about encapsulation and weights use.
* feat: Add `score_roc_measures()` to score a prediction on various ROC measures.
* feat: A better error message is thrown for a failure that often happens when incorrectly configuring the `validate` field of a `GraphLearner`.
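The split described in the changelog can be illustrated with a short sketch. This is an assumption-laden usage example, not part of the PR: the task, data, and column names are made up, and only `use_weights`, `weights_learner`, and `weights_measure` come from this PR.

```r
library(mlr3)

# toy regression task with a numeric column "w" meant as observation weights
df = data.frame(y = rnorm(20), x = rnorm(20), w = runif(20))
task = as_task_regr(df, target = "y")

# assign the new column roles introduced by this PR
task$set_col_roles("w", roles = "weights_learner")   # used during training
# task$set_col_roles("w", roles = "weights_measure") # used during scoring

learner = lrn("regr.rpart")
learner$use_weights = "use"  # default for learners with the "weights" property
learner$train(task)

# a learner WITHOUT the "weights" property errors on this task by default;
# setting its use_weights field to "ignore" drops the weights instead
```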
91 changes: 87 additions & 4 deletions R/Learner.R
@@ -67,6 +67,23 @@
#' Only available for [`Learner`]s with the `"internal_tuning"` property.
#' If the learner is not trained yet, this returns `NULL`.
#'
#' @section Weights:
#'
#' Many learners support observation weights, indicated by their property `"weights"`.
#' The weights are stored in the [Task] where the column role `weights_learner` needs to be assigned to a single numeric column.
#' If a task has weights and the learner supports them, they are used automatically.
#' If a task has weights but the learner does not support them, an error is thrown by default.
#' Both of these behaviors can be disabled by setting the `use_weights` field to `"ignore"`.
#' See the description of `use_weights` for more information.
#'
#' If the learner is set up to use weights but the task has no designated weight column, all observations are assumed to have equal weight.
#' When weights are being used, they are passed down to the learner directly; the effect of weights depends on the specific learner.
#' Generally, weights do not need to sum up to 1.
#'
#' When implementing a Learner that uses weights, the `"weights"` property should be set.
#' The `$.train()` method should then call the `$.get_weights()` method to retrieve the weights from the task.
#' `$.get_weights()` automatically discards the weights when `use_weights` is set to `"ignore"`.
#' @section Setting Hyperparameters:
#'
#' All information about hyperparameters is stored in the slot `param_set` which is a [paradox::ParamSet].
@@ -207,7 +224,6 @@ Learner = R6Class("Learner",
self$id = assert_string(id, min.chars = 1L)
self$label = assert_string(label, na.ok = TRUE)
self$task_type = assert_choice(task_type, mlr_reflections$task_types$type)
private$.param_set = assert_param_set(param_set)
self$feature_types = assert_ordered_set(feature_types, mlr_reflections$task_feature_types, .var.name = "feature_types")
private$.predict_types = assert_ordered_set(predict_types, names(mlr_reflections$learner_predict_types[[task_type]]),
empty.ok = FALSE, .var.name = "predict_types")
@@ -217,6 +233,13 @@
self$packages = union("mlr3", assert_character(packages, any.missing = FALSE, min.chars = 1L))
self$man = assert_string(man, na.ok = TRUE)

if ("weights" %in% self$properties) {
self$use_weights = "use"
} else {
self$use_weights = "error"
}
private$.param_set = param_set

check_packages_installed(packages, msg = sprintf("Package '%%s' required but not installed for Learner '%s'", id))
},

@@ -238,7 +261,9 @@
catn(str_indent("* Packages:", self$packages))
catn(str_indent("* Predict Types: ", replace(self$predict_types, self$predict_types == self$predict_type, paste0("[", self$predict_type, "]"))))
catn(str_indent("* Feature Types:", self$feature_types))
catn(str_indent("* Encapsulation:", paste0(self$encapsulation[[1]], " (fallback: ", if (is.null(self$fallback)) "-" else class(self$fallback)[[1L]], ")")))
catn(str_indent("* Properties:", self$properties))
catn(str_indent("* Other settings:", paste0("use_weights = '", self$use_weights, "'")))
w = self$warnings
e = self$errors
if (length(w)) {
@@ -377,6 +402,10 @@
#' Further, [`auto_convert`] is used for type conversions to ensure compatibility
#' of features between `$train()` and `$predict()`.
#'
#' If the stored training task has a `weights_measure` column, *and* if `newdata` contains a column with the same name,
#' that column must be numeric with no missing values and is used as measure weights column.
#' Otherwise, no measure weights are used.
#'
#' @param newdata (any object supported by [as_data_backend()])\cr
#' New data to predict on.
#' All data formats convertible by [as_data_backend()] are supported, e.g.
@@ -403,7 +432,8 @@
assert_names(newdata$colnames, must.include = task$feature_names)

# the following columns are automatically set to NA if missing
impute = unlist(task$col_roles[c("target", "name", "order", "stratum", "group", "weight")], use.names = FALSE)
# We do not impute weights_measure, because we deliberately do not have weights_measure in this case.
impute = unlist(task$col_roles[c("target", "name", "order", "stratum", "group")], use.names = FALSE)
impute = setdiff(impute, newdata$colnames)
tab1 = if (length(impute)) {
# create list with correct NA types and cbind it to the backend
@@ -432,6 +462,20 @@
task$col_info[, c("label", "fix_factor_levels")] = prevci[list(task$col_info$id), on = "id", c("label", "fix_factor_levels")]
task$col_info$fix_factor_levels[is.na(task$col_info$fix_factor_levels)] = FALSE
task$row_roles$use = task$backend$rownames
task_col_roles = task$col_roles
update_col_roles = FALSE
if (any(task_col_roles$weights_measure %nin% newdata$colnames)) {
update_col_roles = TRUE
task_col_roles$weights_measure = character(0)
}
if (any(task_col_roles$weights_learner %nin% newdata$colnames)) {
update_col_roles = TRUE
task_col_roles$weights_learner = character(0)
}
if (update_col_roles) {
task$col_roles = task_col_roles
}

self$predict(task)
},

@@ -585,6 +629,28 @@
),

active = list(
#' @field use_weights (`character(1)`)\cr
#' How weights should be handled.
#' Possible settings are `"use"`, `"ignore"`, and `"error"`.
#'
#' * `"use"`: use weights, as supported by the underlying `Learner`.
#' Only available for `Learner`s with the property `"weights"`.
#' * `"ignore"`: do not use weights.
#' * `"error"`: throw an error if weights are present in the training `Task`.
#'
#' For `Learner`s with the property `"weights"`, this is initialized as `"use"`.
#' For `Learner`s that do not support weights, i.e. without the `"weights"` property, this is initialized as `"error"`.
#' The latter behavior is to avoid cases where a user erroneously assumes that a `Learner` supports weights when it does not.
#' For `Learner`s that do not support weights, `use_weights` needs to be set to `"ignore"` if tasks with weights should be handled (by dropping the weights).
#' See Section 'weights' for more details.
use_weights = function(rhs) {
if (!missing(rhs)) {
assert_choice(rhs, c(if ("weights" %in% self$properties) "use", "ignore", "error"))
private$.use_weights = rhs
}
private$.use_weights
},

#' @field data_formats (`character()`)\cr
#' Supported data formats. Always `"data.table"`.
#' This is deprecated and will be removed in the future.
@@ -645,15 +711,15 @@
hash = function(rhs) {
assert_ro_binding(rhs)
calculate_hash(class(self), self$id, self$param_set$values, private$.predict_type,
self$fallback$hash, self$parallel_predict, get0("validate", self), self$predict_sets)
self$fallback$hash, self$parallel_predict, get0("validate", self), self$predict_sets, private$.use_weights)
},

#' @field phash (`character(1)`)\cr
#' Hash (unique identifier) for this partial object, excluding some components which are varied systematically during tuning (parameter values).
phash = function(rhs) {
assert_ro_binding(rhs)
calculate_hash(class(self), self$id, private$.predict_type,
self$fallback$hash, self$parallel_predict, get0("validate", self))
self$fallback$hash, self$parallel_predict, get0("validate", self), private$.use_weights)
},

#' @field predict_type (`character(1)`)\cr
@@ -728,6 +794,7 @@
),

private = list(
.use_weights = NULL,
.encapsulation = c(train = "none", predict = "none"),
.fallback = NULL,
.predict_type = NULL,
@@ -736,6 +803,22 @@
.hotstart_stack = NULL,
.selected_features_impute = "error",

# retrieve weights from a task, if it has weights and if the user did not
# deactivate weight usage through `self$use_weights`.
# - `task`: Task to retrieve weights from
# - `no_weights_val`: Value to return if no weights are found (default NULL)
# return: Numeric vector of weights or `no_weights_val` (default NULL)
.get_weights = function(task, no_weights_val = NULL) {
if ("weights" %nin% self$properties) {
stop("private$.get_weights should not be used in Learners that do not have the 'weights' property.")
}
if (self$use_weights == "use" && "weights_learner" %in% task$properties) {
task$weights_learner$weight
} else {
no_weights_val
}
},

deep_clone = function(name, value) {
switch(name,
.param_set = value$clone(deep = TRUE),
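The implementer-facing side of `private$.get_weights()` can be sketched with a minimal custom learner. This is a hypothetical example: the class, id, and model structure are invented for illustration; only the `"weights"` property and the `.get_weights(task, no_weights_val)` helper come from this PR.

```r
library(R6)
library(mlr3)

# Hypothetical learner that predicts the weighted mean of the target.
LearnerRegrWeightedMean = R6Class("LearnerRegrWeightedMean",
  inherit = LearnerRegr,
  public = list(
    initialize = function() {
      super$initialize(
        id = "regr.wmean",
        feature_types = mlr_reflections$task_feature_types,
        predict_types = "response",
        properties = "weights"  # declaring support enables private$.get_weights()
      )
    }
  ),
  private = list(
    .train = function(task) {
      y = task$truth()
      # falls back to unit weights if the task carries none
      w = private$.get_weights(task, no_weights_val = rep(1, length(y)))
      list(wmean = weighted.mean(y, w))
    },
    .predict = function(task) {
      list(response = rep(self$model$wmean, task$nrow))
    }
  )
)
```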
4 changes: 2 additions & 2 deletions R/LearnerClassifDebug.R
@@ -87,7 +87,7 @@ LearnerClassifDebug = R6Class("LearnerClassifDebug", inherit = LearnerClassif,
param_set = param_set,
feature_types = c("logical", "integer", "numeric", "character", "factor", "ordered"),
predict_types = c("response", "prob"),
properties = c("twoclass", "multiclass", "missings", "hotstart_forward", "validation", "internal_tuning", "marshal"),
properties = c("twoclass", "multiclass", "missings", "hotstart_forward", "validation", "internal_tuning", "marshal", "weights"),
man = "mlr3::mlr_learners_classif.debug",
label = "Debug Learner for Classification"
)
@@ -191,7 +191,7 @@ LearnerClassifDebug = R6Class("LearnerClassifDebug", inherit = LearnerClassif,
}

model = list(
response = as.character(sample(task$truth(), 1L)),
response = as.character(sample(task$truth(), 1L, prob = private$.get_weights(task))),
pid = Sys.getpid(),
id = UUIDgenerate(),
random_number = sample(100000, 1),
11 changes: 8 additions & 3 deletions R/LearnerClassifFeatureless.R
@@ -10,14 +10,17 @@
#' \item{mode:}{
#' Predicts the most frequent label. If there are two or more labels tied, randomly selects one per prediction.
#' Probabilities correspond to the relative frequency of the class labels in the training set.
#' For weighted data, the label(s) with the highest weighted frequency are selected.
#' }
#' \item{sample:}{
#' Predicts a label drawn uniformly at random.
#' Probabilities correspond to a uniform distribution over the class labels, i.e. 1 divided by the number of classes.
#' Weights are ignored if present.
#' }
#' \item{weighted.sample:}{
#' Predicts a label drawn at random, with probabilities estimated from the training distribution.
#' For consistency, probabilities are 1 for the sampled label and 0 for all other labels.
#' For weighted data, sample weights are used to weight the class labels.
#' }
#' }
#'
@@ -40,7 +43,7 @@
feature_types = mlr_reflections$task_feature_types,
predict_types = c("response", "prob"),
param_set = ps,
properties = c("featureless", "twoclass", "multiclass", "missings", "importance", "selected_features"),
properties = c("featureless", "twoclass", "multiclass", "missings", "importance", "selected_features", "weights"),
label = "Featureless Classification Learner",
man = "mlr3::mlr_learners_classif.featureless",
)
@@ -67,8 +70,10 @@

private = list(
.train = function(task) {
tn = task$target_names
set_class(list(tab = table(task$data(cols = tn)[[1L]]), features = task$feature_names), "classif.featureless_model")
weights = NULL # silence "no visible binding" note for data.table NSE in R CMD check
counts_table = data.table(truth = task$truth(), weights = private$.get_weights(task, 1.0))[, list(weights = sum(weights)), by = "truth"]
tab = set_names(counts_table$weights, counts_table$truth)
set_class(list(tab = tab, features = task$feature_names), "classif.featureless_model")
},

.predict = function(task) {
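The weighted class counts built in the featureless learner's `.train()` above can be reproduced standalone. A hedged sketch with invented values, where `w` stands in for the result of `private$.get_weights(task, 1.0)`:

```r
library(data.table)
library(stats)  # setNames

truth = factor(c("a", "a", "b"))
w = c(1, 1, 4)  # stand-in for private$.get_weights(task, 1.0)

# sum the weights per class instead of counting rows
counts_table = data.table(truth = truth, weights = w)[
  , list(weights = sum(weights)), by = "truth"]
tab = setNames(counts_table$weights, as.character(counts_table$truth))

# "b" has weighted frequency 4 vs. 2 for "a", so despite being the
# minority class by row count, predict_type "mode" would select it
```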
8 changes: 2 additions & 6 deletions R/LearnerClassifRpart.R
@@ -35,9 +35,8 @@ LearnerClassifRpart = R6Class("LearnerClassifRpart", inherit = LearnerClassif,
minsplit = p_int(1L, default = 20L, tags = "train"),
surrogatestyle = p_int(0L, 1L, default = 0L, tags = "train"),
usesurrogate = p_int(0L, 2L, default = 2L, tags = "train"),
xval = p_int(0L, default = 10L, tags = "train")
xval = p_int(0L, default = 10L, init = 0L, tags = "train")
)
ps$set_values(xval = 0L)

super$initialize(
id = "classif.rpart",
@@ -77,10 +76,7 @@ LearnerClassifRpart = R6Class("LearnerClassifRpart", inherit = LearnerClassif,
.train = function(task) {
pv = self$param_set$get_values(tags = "train")
names(pv) = replace(names(pv), names(pv) == "keep_model", "model")
if ("weights" %chin% task$properties) {
pv = insert_named(pv, list(weights = task$weights$weight))
}

pv$weights = private$.get_weights(task)
invoke(rpart::rpart, formula = task$formula(), data = task$data(), .args = pv, .opts = allow_partial_matching)
},

11 changes: 7 additions & 4 deletions R/LearnerRegrDebug.R
@@ -44,7 +44,7 @@ LearnerRegrDebug = R6Class("LearnerRegrDebug", inherit = LearnerRegr,
threads = p_int(1L, tags = c("train", "threads")),
x = p_dbl(0, 1, tags = "train")
),
properties = "missings",
properties = c("missings", "weights"),
packages = "stats",
man = "mlr3::mlr_learners_regr.debug",
label = "Debug Learner for Regression"
)
@@ -75,15 +76,17 @@ LearnerRegrDebug = R6Class("LearnerRegrDebug", inherit = LearnerRegr,
.train = function(task) {
pv = self$param_set$get_values(tags = "train")
truth = task$truth()
weights = private$.get_weights(task)
wmd = weighted_mean_sd(truth, weights)
model = list(
response = mean(truth),
se = sd(truth),
response = wmd$mean,
se = wmd$sd,
pid = Sys.getpid()
)

if (self$predict_type == "quantiles") {
probs = self$quantiles
model$quantiles = unname(quantile(truth, probs))
model$quantiles = unname(quantile_weighted(truth, probs, weights = weights))
model$quantile_probs = probs
}

27 changes: 20 additions & 7 deletions R/LearnerRegrFeatureless.R
@@ -10,6 +10,18 @@
#' If `robust` is `TRUE`, [median()] and [mad()] are used instead of [mean()] and [sd()],
#' respectively.
#'
#' For weighted data, the response is the weighted mean (weighted median for robust regression).
#' The predicted standard error is the square root of the weighted variance estimator with bias correction
#' based on effective degrees of freedom:
#' ```
#' sd(y, weights) = sqrt(
#' sum(weights * (y - weighted.mean(y, weights))^2) /
#' (sum(weights) - sum(weights^2) / sum(weights))
#' )
#' ```
#' If `robust` is `TRUE`, the weighted median absolute deviation is used, adjusted by a factor of 1.4826
#' for consistency with [mad()].
#'
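A standalone R sketch of the estimator above. The helper name `weighted_mean_sd` mirrors the one called in the diffs, but this implementation is an assumption about its internals, written directly from the documented formula:

```r
# Weighted mean and bias-corrected weighted sd; with unit weights the
# denominator reduces to n - 1, matching stats::sd().
weighted_mean_sd = function(y, weights = rep(1, length(y))) {
  m = sum(weights * y) / sum(weights)
  # bias correction via effective degrees of freedom
  denom = sum(weights) - sum(weights^2) / sum(weights)
  list(mean = m, sd = sqrt(sum(weights * (y - m)^2) / denom))
}
```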
#' @templateVar id regr.featureless
#' @template learner
#'
@@ -30,7 +42,7 @@
feature_types = unname(mlr_reflections$task_feature_types),
predict_types = c("response", "se", "quantiles"),
param_set = ps,
properties = c("featureless", "missings", "importance", "selected_features"),
properties = c("featureless", "missings", "importance", "selected_features", "weights"),
packages = "stats",
label = "Featureless Regression Learner",
man = "mlr3::mlr_learners_regr.featureless"
@@ -61,20 +73,21 @@
.train = function(task) {
pv = self$param_set$get_values(tags = "train")
x = task$data(cols = task$target_names)[[1L]]

weights = private$.get_weights(task)
quantiles = if (self$predict_type == "quantiles") {
if (is.null(private$.quantiles) || is.null(private$.quantile_response)) {
stopf("Quantiles '$quantiles' and response quantile '$quantile_response' must be set")
}
quantile(x, probs = private$.quantiles)
quantile_weighted(x, probs = private$.quantiles, weights = weights)
}

if (isFALSE(pv$robust)) {
location = mean(x)
dispersion = sd(x)
wmd = weighted_mean_sd(x, weights)
location = wmd$mean
dispersion = wmd$sd
} else {
location = stats::median(x)
dispersion = stats::mad(x, center = location)
location = quantile_weighted(x, probs = 0.5, weights = weights, continuous = FALSE)
dispersion = quantile_weighted(abs(x - location), probs = 0.5, weights = weights, continuous = FALSE) * 1.4826
}

set_class(list(
Expand Down
Loading