Skip to content

Commit 6bc3f3d

Browse files
bblodfonbe-marc
authored andcommitted
fix: xgboost offset/base_margin (#371)
1 parent a76a912 commit 6bc3f3d

19 files changed

+333
-181
lines changed

DESCRIPTION

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,4 +84,5 @@ Collate:
8484
'helpers.R'
8585
'helpers_glmnet.R'
8686
'helpers_ranger.R'
87+
'helpers_xgboost.R'
8788
'zzz.R'

NEWS.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
# mlr3learners (development version)
22

3+
* fix: using offset during prediction for `xgboost` learners
4+
35
# mlr3learners 0.14.0
46

57
* compatibility: xgboost 3.1.2.1

R/LearnerClassifCVGlmnet.R

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ LearnerClassifCVGlmnet = R6Class("LearnerClassifCVGlmnet",
5858
mxit = p_int(1L, default = 100L, tags = "train"),
5959
nfolds = p_int(3L, default = 10L, tags = "train"),
6060
nlambda = p_int(1L, default = 100L, tags = "train"),
61-
use_pred_offset = p_lgl(default = TRUE, tags = "predict"),
61+
use_pred_offset = p_lgl(init = TRUE, tags = "predict"),
6262
parallel = p_lgl(default = FALSE, tags = "train"),
6363
penalty.factor = p_uty(tags = "train"),
6464
pmax = p_int(0L, tags = "train"),
@@ -78,8 +78,6 @@ LearnerClassifCVGlmnet = R6Class("LearnerClassifCVGlmnet",
7878
upper.limits = p_uty(tags = "train")
7979
)
8080

81-
ps$set_values(use_pred_offset = TRUE)
82-
8381
super$initialize(
8482
id = "classif.cv_glmnet",
8583
param_set = ps,

R/LearnerClassifGlmnet.R

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ LearnerClassifGlmnet = R6Class("LearnerClassifGlmnet",
6565
mxit = p_int(1L, default = 100L, tags = "train"),
6666
mxitnr = p_int(1L, default = 25L, tags = "train"),
6767
nlambda = p_int(1L, default = 100L, tags = "train"),
68-
use_pred_offset = p_lgl(default = TRUE, tags = "predict"),
68+
use_pred_offset = p_lgl(init = TRUE, tags = "predict"),
6969
penalty.factor = p_uty(tags = "train"),
7070
pmax = p_int(0L, tags = "train"),
7171
pmin = p_dbl(0, 1, default = 1.0e-9, tags = "train"),
@@ -82,8 +82,6 @@ LearnerClassifGlmnet = R6Class("LearnerClassifGlmnet",
8282
upper.limits = p_uty(tags = "train")
8383
)
8484

85-
ps$set_values(use_pred_offset = TRUE)
86-
8785
super$initialize(
8886
id = "classif.glmnet",
8987
param_set = ps,

R/LearnerClassifXgboost.R

Lines changed: 30 additions & 104 deletions
Original file line numberDiff line numberDiff line change
@@ -98,55 +98,55 @@ LearnerClassifXgboost = R6Class("LearnerClassifXgboost",
9898
ps = ps(
9999
alpha = p_dbl(0, default = 0, tags = "train"),
100100
approxcontrib = p_lgl(default = FALSE, tags = "predict"),
101-
base_score = p_dbl(default = 0.5, tags = "train"),
101+
base_score = p_dbl(tags = "train"),
102102
booster = p_fct(c("gbtree", "gblinear", "dart"), default = "gbtree", tags = "train"),
103103
callbacks = p_uty(default = list(), tags = "train"),
104-
colsample_bylevel = p_dbl(0, 1, default = 1, tags = "train"),
105-
colsample_bynode = p_dbl(0, 1, default = 1, tags = "train"),
106-
colsample_bytree = p_dbl(0, 1, default = 1, tags = "train"),
104+
colsample_bylevel = p_dbl(0, 1, default = 1, tags = "train", depends = quote(booster == "gbtree")),
105+
colsample_bynode = p_dbl(0, 1, default = 1, tags = "train", depends = quote(booster == "gbtree")),
106+
colsample_bytree = p_dbl(0, 1, default = 1, tags = "train", depends = quote(booster == "gbtree")),
107107
device = p_uty(default = "cpu", tags = "train"),
108108
disable_default_eval_metric = p_lgl(default = FALSE, tags = "train"),
109109
early_stopping_rounds = p_int(1L, default = NULL, special_vals = list(NULL), tags = "train"),
110110
eta = p_dbl(0, 1, default = 0.3, tags = "train"),
111111
evals = p_uty(default = NULL, tags = "train"),
112112
eval_metric = p_uty(tags = "train"),
113113
custom_metric = p_uty(tags = "train", custom_check = crate({function(x) check_true(any(is.function(x), test_multi_class(x, c("MeasureClassifSimple", "MeasureBinarySimple"))))})),
114-
extmem_single_page = p_lgl(default = FALSE, tags = "train"),
114+
extmem_single_page = p_lgl(default = FALSE, tags = "train", depends = quote(tree_method %in% c("hist", "approx"))),
115115
feature_selector = p_fct(c("cyclic", "shuffle", "random", "greedy", "thrifty"), default = "cyclic", tags = "train", depends = quote(booster == "gblinear")),
116116
gamma = p_dbl(0, default = 0, tags = "train"),
117-
grow_policy = p_fct(c("depthwise", "lossguide"), default = "depthwise", tags = "train", depends = quote(tree_method == "hist")),
118-
interaction_constraints = p_uty(tags = "train"),
117+
grow_policy = p_fct(c("depthwise", "lossguide"), default = "depthwise", tags = "train", depends = quote(booster == "gbtree" && tree_method %in% c("hist", "approx"))),
118+
interaction_constraints = p_uty(tags = "train", depends = quote(booster == "gbtree")),
119119
iterationrange = p_uty(tags = "predict"),
120120
lambda = p_dbl(0, default = 1, tags = "train"),
121-
max_bin = p_int(2L, default = 256L, tags = "train", depends = quote(tree_method == "hist")),
122-
max_cached_hist_node = p_int(default = 65536L, tags = "train", depends = quote(tree_method == "hist")),
123-
max_cat_to_onehot = p_int(tags = "train"),
124-
max_cat_threshold = p_dbl(tags = "train"),
125-
max_delta_step = p_dbl(0, default = 0, tags = "train"),
126-
max_depth = p_int(0L, default = 6L, tags = "train"),
127-
max_leaves = p_int(0L, default = 0L, tags = "train", depends = quote(grow_policy == "lossguide")),
121+
max_bin = p_int(2L, default = 256L, tags = "train", depends = quote(tree_method %in% c("hist", "approx"))),
122+
max_cached_hist_node = p_int(default = 65536L, tags = "train", depends = quote(tree_method %in% c("hist", "approx"))),
123+
max_cat_to_onehot = p_int(tags = "train", depends = quote(tree_method %in% c("hist", "approx"))),
124+
max_cat_threshold = p_dbl(tags = "train", depends = quote(tree_method %in% c("hist", "approx"))),
125+
max_delta_step = p_dbl(0, default = 0, tags = "train", depends = quote(booster == "gbtree")),
126+
max_depth = p_int(0L, default = 6L, tags = "train", depends = quote(booster == "gbtree")),
127+
max_leaves = p_int(0L, default = 0L, tags = "train", depends = quote(booster == "gbtree")),
128128
maximize = p_lgl(default = NULL, special_vals = list(NULL), tags = "train"),
129-
min_child_weight = p_dbl(0, default = 1, tags = "train"),
129+
min_child_weight = p_dbl(0, default = 1, tags = "train", depends = quote(booster == "gbtree")),
130130
missing = p_dbl(default = NA, tags = "predict", special_vals = list(NA, NA_real_, NULL)),
131-
monotone_constraints = p_uty(default = 0, tags = "train", custom_check = crate(function(x) { checkmate::check_integerish(x, lower = -1, upper = 1, any.missing = FALSE) })), # nolint
131+
monotone_constraints = p_uty(default = 0, tags = "train", custom_check = crate(function(x) { checkmate::check_integerish(x, lower = -1, upper = 1, any.missing = FALSE) }), depends = quote(booster == "gbtree")), # nolint
132132
nrounds = p_nrounds,
133133
normalize_type = p_fct(c("tree", "forest"), default = "tree", tags = "train", depends = quote(booster == "dart")),
134134
nthread = p_int(1L, init = 1L, tags = c("train", "threads")),
135-
num_parallel_tree = p_int(1L, default = 1L, tags = "train"),
135+
num_parallel_tree = p_int(1L, default = 1L, tags = "train", depends = quote(booster == "gbtree")),
136136
objective = p_uty(default = "binary:logistic", tags = c("train", "predict")),
137137
one_drop = p_lgl(default = FALSE, tags = "train", depends = quote(booster == "dart")),
138138
print_every_n = p_int(1L, default = 1L, tags = "train", depends = quote(verbose == 1L)),
139139
rate_drop = p_dbl(0, 1, default = 0, tags = "train", depends = quote(booster == "dart")),
140-
refresh_leaf = p_lgl(default = TRUE, tags = "train"),
140+
refresh_leaf = p_lgl(default = TRUE, tags = "train", depends = quote(booster == "gbtree")),
141141
seed = p_int(tags = "train"),
142142
seed_per_iteration = p_lgl(default = FALSE, tags = "train"),
143143
sampling_method = p_fct(c("uniform", "gradient_based"), default = "uniform", tags = "train", depends = quote(booster == "gbtree")),
144144
sample_type = p_fct(c("uniform", "weighted"), default = "uniform", tags = "train", depends = quote(booster == "dart")),
145145
save_name = p_uty(default = NULL, tags = "train"),
146146
save_period = p_int(0, default = NULL, special_vals = list(NULL), tags = "train"),
147-
scale_pos_weight = p_dbl(default = 1, tags = "train"),
147+
scale_pos_weight = p_dbl(default = 1, tags = "train", depends = quote(booster == "gbtree")),
148148
skip_drop = p_dbl(0, 1, default = 0, tags = "train", depends = quote(booster == "dart")),
149-
subsample = p_dbl(0, 1, default = 1, tags = "train"),
149+
subsample = p_dbl(0, 1, default = 1, tags = "train", depends = quote(booster == "gbtree")),
150150
top_k = p_int(0, default = 0, tags = "train", depends = quote(feature_selector %in% c("greedy", "thrifty") && booster == "gblinear")),
151151
training = p_lgl(default = FALSE, tags = "predict"),
152152
tree_method = p_fct(c("auto", "exact", "approx", "hist", "gpu_hist"), default = "auto", tags = "train", depends = quote(booster %in% c("gbtree", "dart"))),
@@ -156,7 +156,8 @@ LearnerClassifXgboost = R6Class("LearnerClassifXgboost",
156156
validate_features = p_lgl(default = TRUE, tags = "predict"),
157157
verbose = p_int(0L, 2L, init = 0L, tags = "train"),
158158
verbosity = p_int(0L, 2L, init = 0L, tags = "train"),
159-
xgb_model = p_uty(default = NULL, tags = "train")
159+
xgb_model = p_uty(default = NULL, tags = "train"),
160+
use_pred_offset = p_lgl(init = TRUE, tags = "predict")
160161
)
161162

162163
super$initialize(
@@ -190,7 +191,7 @@ LearnerClassifXgboost = R6Class("LearnerClassifXgboost",
190191
#' @field internal_valid_scores (named `list()` or `NULL`)
191192
#' The validation scores extracted from `model$evaluation_log`.
192193
#' If early stopping is activated, this contains the validation scores of the model for the optimal `nrounds`,
193-
#' otherwise the `nrounds` for the final model.
194+
#' otherwise the scores are taken from the final boosting round `nrounds`.
194195
internal_valid_scores = function() {
195196
self$state$internal_valid_scores
196197
},
@@ -261,22 +262,8 @@ LearnerClassifXgboost = R6Class("LearnerClassifXgboost",
261262
xgboost::setinfo(xgb_data, "weight", weights)
262263
}
263264

264-
if ("offset" %in% task$properties) {
265-
offset = task$offset
266-
if (nlvls == 2L) {
267-
# binary case
268-
base_margin = offset$offset
269-
} else {
270-
# multiclass needs a matrix (n_samples, n_classes)
271-
# it seems reasonable to reorder according to label (0,1,2,...)
272-
reordered_cols = paste0("offset_", rev(levels(task$truth())))
273-
n_offsets = ncol(offset) - 1 # all expect `row_id`
274-
if (length(reordered_cols) != n_offsets) {
275-
stopf("Task has %i class labels, and only %i offset columns are provided",
276-
nlevels(task$truth()), n_offsets)
277-
}
278-
base_margin = as_numeric_matrix(offset)[, reordered_cols]
279-
}
265+
base_margin = xgboost_get_base_margin(task, "train", pv)
266+
if (!is.null(base_margin)) {
280267
xgboost::setinfo(xgb_data, "base_margin", base_margin)
281268
}
282269

@@ -292,22 +279,13 @@ LearnerClassifXgboost = R6Class("LearnerClassifXgboost",
292279
xgb_valid_data = xgboost::xgb.DMatrix(data = as_numeric_matrix(valid_data), label = valid_label)
293280

294281
weights = get_weights(internal_valid_task, private)
295-
296282
if (!is.null(weights)) {
297283
xgboost::setinfo(xgb_valid_data, "weight", weights)
298284
}
299285

300-
if ("offset" %in% internal_valid_task$properties) {
301-
valid_offset = internal_valid_task$offset
302-
if (nlvls == 2L) {
303-
base_margin = valid_offset$offset
304-
} else {
305-
# multiclass needs a matrix (n_samples, n_classes)
306-
# it seems reasonable to reorder according to label (0,1,2,...)
307-
reordered_cols = paste0("offset_", rev(levels(internal_valid_task$truth())))
308-
base_margin = as_numeric_matrix(valid_offset)[, reordered_cols]
309-
}
310-
xgboost::setinfo(xgb_valid_data, "base_margin", base_margin)
286+
valid_base_margin = xgboost_get_base_margin(internal_valid_task, "train", pv)
287+
if (!is.null(base_margin)) {
288+
xgboost::setinfo(xgb_valid_data, "base_margin", valid_base_margin)
311289
}
312290

313291
pv$evals = c(pv$evals, list(test = xgb_valid_data))
@@ -371,6 +349,8 @@ LearnerClassifXgboost = R6Class("LearnerClassifXgboost",
371349
pv$objective = ifelse(nlvls == 2L, "binary:logistic", "multi:softprob")
372350
}
373351

352+
pv$base_margin = xgboost_get_base_margin(task, "predict", pv)
353+
374354
newdata = as_numeric_matrix(ordered_features(task, self))
375355
pred = invoke(predict, model, newdata = newdata, .args = pv)
376356
if (nlvls == 2L) { # binaryclass
@@ -465,7 +445,6 @@ LearnerClassifXgboost = R6Class("LearnerClassifXgboost",
465445
)
466446
)
467447

468-
469448
#' @export
470449
default_values.LearnerClassifXgboost = function(x, search_space, task, ...) { # nolint
471450
special_defaults = list(
@@ -477,56 +456,3 @@ default_values.LearnerClassifXgboost = function(x, search_space, task, ...) { #
477456

478457
#' @include aaa.R
479458
learners[["classif.xgboost"]] = LearnerClassifXgboost
480-
481-
# mlr3 measure to custom inner measure functions
482-
xgboost_binary_binary_prob = function(pred, dtrain, measure, ...) {
483-
# label is a vector of labels (0, 1)
484-
truth = factor(xgboost::getinfo(dtrain, "label"), levels = c(0, 1))
485-
# pred is a vector of log odds
486-
# transform log odds to probabilities
487-
pred = 1 / (1 + exp(-pred))
488-
measure$fun(truth, pred, positive = "1")
489-
}
490-
491-
xgboost_binary_classif_prob = function(pred, dtrain, measure, ...) {
492-
# label is a vector of labels (0, 1)
493-
truth = factor(xgboost::getinfo(dtrain, "label"), levels = c(0, 1))
494-
# pred is a vector of log odds
495-
# transform log odds to probabilities
496-
pred = 1 / (1 + exp(-pred))
497-
# multiclass measure needs a matrix of probabilities
498-
pred_mat = matrix(c(pred, 1 - pred), ncol = 2)
499-
colnames(pred_mat) = c("1", "0")
500-
measure$fun(truth, pred_mat, positive = "1")
501-
}
502-
503-
xgboost_binary_response = function(pred, dtrain, measure, ...) {
504-
# label is a vector of labels (0, 1)
505-
truth = factor(xgboost::getinfo(dtrain, "label"), levels = c(0, 1))
506-
# pred is a vector of log odds
507-
response = factor(as.integer(pred > 0), levels = c(0, 1))
508-
measure$fun(truth, response)
509-
}
510-
511-
xgboost_multiclass_prob = function(pred, dtrain, measure, n_classes, ...) {
512-
# label is a vector of labels (0, 1, ..., n_classes - 1)
513-
truth = factor(xgboost::getinfo(dtrain, "label"), levels = seq_len(n_classes) - 1L)
514-
515-
# pred is a matrix of log odds for each class
516-
# transform log odds to probabilities
517-
pred_exp = exp(pred)
518-
pred_mat = pred_exp / rowSums(pred_exp)
519-
colnames(pred_mat) = levels(truth)
520-
521-
measure$fun(truth, pred_mat)
522-
}
523-
524-
xgboost_multiclass_response = function(pred, dtrain, measure, n_classes, ...) {
525-
# label is a vector of labels (0, 1, ..., n_classes - 1)
526-
truth = factor(xgboost::getinfo(dtrain, "label"), levels = seq_len(n_classes) - 1L)
527-
528-
# pred is a matrix of log odds for each class
529-
response = factor(max.col(pred, ties.method = "random") - 1, levels = levels(truth))
530-
measure$fun(truth, response)
531-
}
532-

R/LearnerRegrCVGlmnet.R

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ LearnerRegrCVGlmnet = R6Class("LearnerRegrCVGlmnet",
5353
mxitnr = p_int(1L, default = 25L, tags = "train"),
5454
nfolds = p_int(3L, default = 10L, tags = "train"),
5555
nlambda = p_int(1L, default = 100L, tags = "train"),
56-
use_pred_offset = p_lgl(default = TRUE, tags = "predict"),
56+
use_pred_offset = p_lgl(init = TRUE, tags = "predict"),
5757
parallel = p_lgl(default = FALSE, tags = "train"),
5858
penalty.factor = p_uty(tags = "train"),
5959
pmax = p_int(0L, tags = "train"),
@@ -73,7 +73,7 @@ LearnerRegrCVGlmnet = R6Class("LearnerRegrCVGlmnet",
7373
upper.limits = p_uty(tags = "train")
7474
)
7575

76-
ps$set_values(family = "gaussian", use_pred_offset = TRUE)
76+
ps$set_values(family = "gaussian")
7777

7878
super$initialize(
7979
id = "regr.cv_glmnet",

R/LearnerRegrGlmnet.R

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ LearnerRegrGlmnet = R6Class("LearnerRegrGlmnet",
5252
mnlam = p_int(1L, default = 5L, tags = "train"),
5353
mxit = p_int(1L, default = 100L, tags = "train"),
5454
mxitnr = p_int(1L, default = 25L, tags = "train"),
55-
use_pred_offset = p_lgl(default = TRUE, tags = "predict"),
55+
use_pred_offset = p_lgl(init = TRUE, tags = "predict"),
5656
nlambda = p_int(1L, default = 100L, tags = "train"),
5757
parallel = p_lgl(default = FALSE, tags = "train"),
5858
penalty.factor = p_uty(tags = "train"),
@@ -71,7 +71,7 @@ LearnerRegrGlmnet = R6Class("LearnerRegrGlmnet",
7171
upper.limits = p_uty(tags = "train")
7272
)
7373

74-
ps$set_values(family = "gaussian", use_pred_offset = TRUE)
74+
ps$set_values(family = "gaussian")
7575

7676
super$initialize(
7777
id = "regr.glmnet",

0 commit comments

Comments
 (0)