Skip to content

Commit 7771311

Browse files
added first tests for PipeOpClassWeightsEx
1 parent 387f54a commit 7771311

File tree

3 files changed

+59
-6
lines changed

3 files changed

+59
-6
lines changed

R/PipeOpClassWeights.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ PipeOpClassWeights = R6Class("PipeOpClassWeights",
9797
),
9898
private = list(
9999
.train_task = function(task) {
100-
browser()
100+
101101
pv = self$param_set$get_values(tags = "train")
102102

103103
if ("twoclass" %nin% task$properties) {

R/PipeOpClassWeightsEx.R

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -101,17 +101,19 @@ PipeOpClassWeightsEx = R6Class("PipeOpClassWeightsEx",
101101
.train_task = function(task) {
102102
pv = self$param_set$get_values(tags = "train")
103103

104+
if (is.null(pv$weight_type) ||
105+
is.null(pv$weight_method) ||
106+
(pv$weight_method == "explicit" && is.null(pv$mapping))) {
107+
return(task)
108+
}
109+
104110
weightcolname = ".WEIGHTS"
105111
if (weightcolname %in% unlist(task$col_roles)) {
106112
stopf("Weight column '%s' is already in the Task", weightcolname)
107113
}
108114

109115
truth = task$truth()
110116

111-
if (is.null(pv$weight_type)) {
112-
return(task)
113-
}
114-
115117
class_frequency = table(truth) / length(truth)
116118
class_names = names(class_frequency)
117119

@@ -122,7 +124,8 @@ PipeOpClassWeightsEx = R6Class("PipeOpClassWeightsEx",
122124
"explicit" = pv$mapping
123125
)
124126

125-
wcol = setnames(data.table(weights_by_class[truth])[, "N"], weightcolname)
127+
weights_table = data.table(weights_by_class[truth])
128+
wcol = setnames(as.data.table(weights_table[[ncol(weights_table)]]), weightcolname)
126129
task$cbind(wcol)
127130
task$col_roles$feature = setdiff(task$col_roles$feature, weightcolname)
128131

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
2+
# Basic checks for PipeOpClassWeightsEx: construct the PipeOp, pass it through
# the project's expect_pipeop() helper, run one train and one predict call on
# the "german_credit" task, and finally run the shared
# expect_datapreproc_pipeop_class() checks.
test_that("PipeOpClassWeightsEx - basic properties", {
  op = PipeOpClassWeightsEx$new()
  task = mlr_tasks$get("german_credit")
  expect_pipeop(op)
  train_pipeop(op, inputs = list(task))
  predict_pipeop(op, inputs = list(task))

  # The class handed to the shared checks must be the one this file tests
  # (PipeOpClassWeightsEx), not its sibling PipeOpClassWeights; otherwise the
  # checks never exercise the new PipeOp.
  expect_datapreproc_pipeop_class(PipeOpClassWeightsEx, task = task,
    predict_like_train = FALSE)
})
12+
13+
# Test tied to issue #937 (see the test name): a graph of "classweightsex"
# followed by a learner is trained three times on a small three-class task,
# checking which combinations raise the "Learner does not support weights"
# error.
test_that("PipeOpClassWeightsEx - error for Tasks without weights property, #937", {
  skip_if_not_installed("mlr3learners")
  skip_if_not_installed("MASS")

  # Seed first: the feature column below is drawn with runif().
  set.seed(1234)
  task = as_task_classif(data.table(
    y = factor(rep(c("A", "B", "A", "C"), 4)),
    x = runif(16)
  ), target = "y")

  # Case 1 -- no error: Learner has weights property
  weights_op = po("classweightsex", param_vals = list(mapping = c("A" = 0.6, "B" = 0.3, "C" = 0.1)))
  graph_featureless = weights_op %>>% lrn("classif.featureless")
  expect_no_error(graph_featureless$train(task))

  # Case 2 -- error: Learner does not have weights property
  weights_op = po("classweightsex", param_vals = list(mapping = c("A" = 0.6, "B" = 0.3, "C" = 0.1)))
  graph_lda = weights_op %>>% lrn("classif.lda")
  expect_error(graph_lda$train(task), ".*Learner does not support weights.*")

  # Case 3 -- no error: use_weights is set to "ignore"
  weights_op = po("classweightsex", param_vals = list(mapping = c("A" = 0.6, "B" = 0.4, "C" = 0.1)))
  graph_lda_ignore = weights_op %>>% lrn("classif.lda", use_weights = "ignore")
  expect_no_error(graph_lda_ignore$train(task))
})
36+
37+
# Trains PipeOpClassWeightsEx with weight_method = "inverse_class_frequency"
# on the iris task and checks that the data returned by $data() is the same
# before and after training.
test_that("PipeOpClassWeightsEx", {
  iris_task = mlr_tasks$get("iris")

  # Method inverse_class_frequency
  op_icf = po("classweightsex", param_vals = list(weight_method = "inverse_class_frequency"))
  trained_outputs = op_icf$train(list(iris_task))
  weighted_task = trained_outputs[[1L]]
  expect_equal(weighted_task$data(), iris_task$data())

  # TODO: the weight values themselves are not asserted yet. Draft kept for
  # reference; NOTE(review): it compares against a "neg" class, which looks
  # copied from a two-class test -- adapt it to the iris classes before
  # enabling.
  # manual_weights = as.data.table(1 / table(iris_task$data()$Species))
  # weights = if ("weights_learner" %in% names(weighted_task)) "weights_learner" else "weights"
  # expect_equal(weighted_task[[weights]]$weight, ifelse(weighted_task$truth(weighted_task[[weights]]$row_ids) == "neg", 1, 3))
})

0 commit comments

Comments
 (0)