Skip to content

Commit f50cca1

Browse files
extended tests
1 parent 3563b34 commit f50cca1

File tree

1 file changed

+33
-1
lines changed

1 file changed

+33
-1
lines changed

tests/testthat/test_pipeop_classweightsex.R

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ test_that("PipeOpClassWeightsEx - error for Tasks without weights property, #937
3636

3737
test_that("PipeOpClassWeightsEx", {
3838

39-
task = mlr_tasks$get("iris")
39+
task = mlr_tasks$get("penguins")
4040

4141
# Method inverse_class_frequency
4242
poicf = po("classweightsex", param_vals = list(weight_method = "inverse_class_frequency"))
@@ -54,6 +54,38 @@ test_that("PipeOpClassWeightsEx", {
5454

5555
expect_equal(computed_weights[["weight"]], as.numeric(unclass(manual_weights)))
5656

57+
# Method inverse square root of class frequency
58+
poisf = po("classweightsex", param_vals = list(weight_method = "inverse_square_root_of_frequency"))
59+
nt = poisf$train(list(task))[[1L]]
60+
expect_equal(nt$data(), task$data())
61+
62+
freq = prop.table(table(task$truth()))
63+
manual_weights = 1 / sqrt(freq[task$truth()])
64+
65+
if ("weights_learner" %in% names(nt$col_roles)) {
66+
computed_weights = nt$weights_learner
67+
} else {
68+
computed_weights = nt$weights
69+
}
70+
71+
expect_equal(computed_weights[["weight"]], as.numeric(unclass(manual_weights)))
72+
73+
# Method median frequency balancing
74+
pomfb = po("classweightsex", param_vals = list(weight_method = "median_frequency_balancing"))
75+
nt = pomfb$train(list(task))[[1L]]
76+
expect_equal(nt$data(), task$data())
77+
78+
#freq = prop.table(table(task$truth()))
79+
#manual_weights = median(freq) / freq
80+
#manual_weights = 1 / sqrt(freq[task$truth()])
81+
82+
#if ("weights_learner" %in% names(nt$col_roles)) {
83+
# computed_weights = nt$weights_learner
84+
#} else {
85+
# computed_weights = nt$weights
86+
#}
87+
88+
#expect_equal(computed_weights[["weight"]], as.numeric(unclass(manual_weights)))
5789

5890
# weights = if ("weights_learner" %in% names(nt)) "weights_learner" else "weights"
5991
# expect_equal(nt[[weights]]$weight, ifelse(nt$truth(nt[[weights]]$row_ids) == "neg", 1, 3))

0 commit comments

Comments
 (0)