1515# '
1616# ' @inheritParams sens
1717# '
18+ # ' @param weighting A weighting to apply when computing the scores. One of:
19+ # ' `"none"`, `"linear"`, or `"quadratic"`. Linear and quadratic weighting
20+ # ' penalizes mis-predictions that are "far away" from the true value. Note
21+ # ' that distance is judged based on the ordering of the levels in `truth` and
22+ # ' `estimate`. It is recommended to provide ordered factors for `truth` and
23+ # ' `estimate` to explicitly code the ordering, but this is not required.
24+ # '
25+ # ' In the binary case, all 3 weightings produce the same value, since it is
26+ # ' only ever possible to be 1 unit away from the true value.
27+ # '
1828# ' @author Max Kuhn
29+ # ' @author Jon Harmon
30+ # '
31+ # ' @references
32+ # ' Cohen, J. (1960). "A coefficient of agreement for nominal
33+ # ' scales". _Educational and Psychological Measurement_. 20 (1): 37-46.
1934# '
20- # ' @references Cohen, J. (1960). "A coefficient of agreement for nominal
21- # ' scales". _Educational and Psychological Measurement_. 20 (1): 37-46.
35+ # ' Cohen, J. (1968). "Weighted kappa: Nominal scale agreement provision for
36+ # ' scaled disagreement or partial credit". _Psychological
37+ # ' Bulletin_. 70 (4): 213-220.
2238# '
2339# ' @export
2440# ' @examples
@@ -49,51 +65,61 @@ kap <- new_class_metric(
4965
5066# ' @export
5167# ' @rdname kap
52- kap.data.frame <- function (data , truth , estimate ,
53- na_rm = TRUE , ... ) {
68+ kap.data.frame <- function (data ,
69+ truth ,
70+ estimate ,
71+ weighting = " none" ,
72+ na_rm = TRUE ,
73+ ... ) {
5474
5575 metric_summarizer(
5676 metric_nm = " kap" ,
5777 metric_fn = kap_vec ,
5878 data = data ,
5979 truth = !! enquo(truth ),
6080 estimate = !! enquo(estimate ),
61- na_rm = na_rm
81+ na_rm = na_rm ,
82+ metric_fn_options = list (weighting = weighting )
6283 )
6384
6485}
6586
6687# ' @export
67- kap.table <- function (data , ... ) {
88+ kap.table <- function (data ,
89+ weighting = " none" ,
90+ ... ) {
6891 check_table(data )
6992 metric_tibbler(
7093 .metric = " kap" ,
7194 .estimator = finalize_estimator(data , metric_class = " kap" ),
72- .estimate = kap_table_impl(data )
95+ .estimate = kap_table_impl(data , weighting = weighting )
7396 )
7497}
7598
7699# ' @export
77- kap.matrix <- function (data , ... ) {
100+ kap.matrix <- function (data ,
101+ weighting = " none" ,
102+ ... ) {
78103 data <- as.table(data )
79- kap.table(data )
104+ kap.table(data , weighting = weighting )
80105}
81106
82107# ' @export
83108# ' @rdname kap
84- kap_vec <- function (truth , estimate , na_rm = TRUE , ... ) {
85-
109+ kap_vec <- function (truth ,
110+ estimate ,
111+ weighting = " none" ,
112+ na_rm = TRUE ,
113+ ... ) {
86114 estimator <- finalize_estimator(truth , metric_class = " kap" )
87115
88- kap_impl <- function (truth , estimate ) {
89-
116+ kap_impl <- function (truth , estimate , weighting ) {
90117 xtab <- vec2table(
91118 truth = truth ,
92119 estimate = estimate
93120 )
94121
95- kap_table_impl(xtab )
96-
122+ kap_table_impl(xtab , weighting = weighting )
97123 }
98124
99125 metric_vec_template(
@@ -102,25 +128,74 @@ kap_vec <- function(truth, estimate, na_rm = TRUE, ...) {
102128 estimate = estimate ,
103129 na_rm = na_rm ,
104130 estimator = estimator ,
105- cls = " factor"
131+ cls = " factor" ,
132+ weighting = weighting
106133 )
107-
108134}
109135
110- kap_table_impl <- function (data ) {
111- kap_binary(data )
136+ kap_table_impl <- function (data , weighting ) {
137+ full_sum <- sum(data )
138+ row_sum <- rowSums(data )
139+ col_sum <- colSums(data )
140+ expected <- outer(row_sum , col_sum ) / full_sum
141+
142+ n_levels <- nrow(data )
143+ w <- make_weighting_matrix(weighting , n_levels )
144+
145+ n_disagree <- sum(w * data )
146+ n_chance <- sum(w * expected )
147+
148+ 1 - n_disagree / n_chance
112149}
113150
114- kap_binary <- function (data ) {
151+ make_weighting_matrix <- function (weighting , n_levels ) {
152+ validate_weighting(weighting )
115153
116- n <- sum(data )
154+ if (is_no_weighting(weighting )) {
155+ # [n_levels x n_levels], 0 on diagonal, 1 on off-diagonal
156+ w <- matrix (1L , nrow = n_levels , ncol = n_levels )
157+ diag(w ) <- 0L
158+ return (w )
159+ }
117160
118- .row_sums <- rowSums(data )
119- .col_sums <- colSums(data )
161+ if (is_linear_weighting(weighting )) {
162+ power <- 1L
163+ } else {
164+ # quadratic
165+ power <- 2L
166+ }
167+
168+ # [n_levels x n_levels], 0 on diagonal, increasing weighting on off-diagonal
169+ w <- rlang :: seq2(0L , n_levels - 1L )
170+ w <- matrix (w , nrow = n_levels , ncol = n_levels )
171+ w <- abs(w - t(w )) ^ power
172+
173+ w
174+ }
175+
176+ # ------------------------------------------------------------------------------
120177
121- expected_acc <- sum( (.row_sums * .col_sums ) / n ) / n
178+ validate_weighting <- function (x ) {
179+ if (! rlang :: is_string(x )) {
180+ abort(" `weighting` must be a string." )
181+ }
122182
123- obs_acc <- accuracy_binary(data )
183+ ok <- is_no_weighting(x ) ||
184+ is_linear_weighting(x ) ||
185+ is_quadratic_weighting(x )
124186
125- (obs_acc - expected_acc ) / (1 - expected_acc )
187+ if (! ok ) {
188+ abort(" `weighting` must be 'none', 'linear', or 'quadratic'." )
189+ }
190+
191+ invisible (x )
192+ }
193+ is_no_weighting <- function (x ) {
194+ identical(x , " none" )
195+ }
196+ is_linear_weighting <- function (x ) {
197+ identical(x , " linear" )
198+ }
199+ is_quadratic_weighting <- function (x ) {
200+ identical(x , " quadratic" )
126201}
0 commit comments