@@ -36,7 +36,7 @@ test_that("PipeOpClassWeightsEx - error for Tasks without weights property, #937
3636
3737test_that(" PipeOpClassWeightsEx"  , {
3838
39-   task  =  mlr_tasks $ get(" iris "  )
39+   task  =  mlr_tasks $ get(" penguins "  )
4040
4141  #  Method inverse_class_frequency
4242  poicf  =  po(" classweightsex"  , param_vals  =  list (weight_method  =  " inverse_class_frequency"  ))
@@ -54,6 +54,38 @@ test_that("PipeOpClassWeightsEx", {
5454
5555  expect_equal(computed_weights [[" weight"  ]], as.numeric(unclass(manual_weights )))
5656
57+   #  Method inverse square root of class frequency
58+   poisf  =  po(" classweightsex"  , param_vals  =  list (weight_method  =  " inverse_square_root_of_frequency"  ))
59+   nt  =  poisf $ train(list (task ))[[1L ]]
60+   expect_equal(nt $ data(), task $ data())
61+ 
62+   freq  =  prop.table(table(task $ truth()))
63+   manual_weights  =  1  /  sqrt(freq [task $ truth()])
64+ 
65+   if  (" weights_learner"   %in%  names(nt $ col_roles )) {
66+     computed_weights  =  nt $ weights_learner 
67+   } else  {
68+     computed_weights  =  nt $ weights 
69+   }
70+ 
71+   expect_equal(computed_weights [[" weight"  ]], as.numeric(unclass(manual_weights )))
72+ 
73+   #  Method median frequency balancing
74+   pomfb  =  po(" classweightsex"  , param_vals  =  list (weight_method  =  " median_frequency_balancing"  ))
75+   nt  =  pomfb $ train(list (task ))[[1L ]]
76+   expect_equal(nt $ data(), task $ data())
77+ 
78+   # freq = prop.table(table(task$truth()))
79+   # manual_weights = median(freq) / freq
80+   # manual_weights = 1 / sqrt(freq[task$truth()])
81+ 
82+   # if ("weights_learner" %in% names(nt$col_roles)) {
83+   #   computed_weights = nt$weights_learner
84+   # } else {
85+   #   computed_weights = nt$weights
86+   # }
87+ 
88+   # expect_equal(computed_weights[["weight"]], as.numeric(unclass(manual_weights)))
5789
5890  #  weights = if ("weights_learner" %in% names(nt)) "weights_learner" else "weights"
5991  #  expect_equal(nt[[weights]]$weight, ifelse(nt$truth(nt[[weights]]$row_ids) == "neg", 1, 3))
0 commit comments