@@ -16,57 +16,57 @@ const metric_dict = Dict{String,PYC.Py}()
# Empty Python object that will hold the `sklearn.metrics` module.
# PythonCall-style pattern: create with `pynew()` at precompile time and
# fill it with `pycopy!` inside `__init__` at runtime.
const SKM = PYC.pynew()
1717
"""
    __init__()

Module initializer: imports Python's `sklearn.metrics` at runtime and
registers the supported metric functions in `metric_dict`, keyed by their
sklearn name.
"""
function __init__()
    # Runtime import is mandatory: Python objects cannot be serialized into
    # the precompile image, so `SKM` is filled here via `pycopy!`.
    PYC.pycopy!(SKM, PYC.pyimport("sklearn.metrics"))

    # Classification metrics.
    # NOTE: the original listed "roc_auc_score" twice; the duplicate is dropped.
    classification = (
        "accuracy_score",
        "auc",
        "average_precision_score",
        "balanced_accuracy_score",
        "brier_score_loss",
        "classification_report",
        "cohen_kappa_score",
        "confusion_matrix",
        "f1_score",
        "fbeta_score",
        "hamming_loss",
        "hinge_loss",
        "jaccard_score",
        "log_loss",
        "matthews_corrcoef",
        "multilabel_confusion_matrix",
        "precision_recall_curve",
        "precision_recall_fscore_support",
        "precision_score",
        "recall_score",
        "roc_auc_score",
        "roc_curve",
        "zero_one_loss",
    )

    # Regression metrics.
    regression = (
        "explained_variance_score",
        "max_error",
        "mean_absolute_error",
        "mean_gamma_deviance",
        "mean_poisson_deviance",
        "mean_squared_error",
        "mean_squared_log_error",
        "mean_tweedie_deviance",
        "median_absolute_error",
        "r2_score",
    )

    # Look each function up by name on the imported module; equivalent to the
    # explicit `metric_dict["name"] = SKM.name` assignments it replaces.
    for name in (classification..., regression...)
        metric_dict[name] = getproperty(SKM, Symbol(name))
    end
end
5757
"""
    checkfun(sfunc::String)

Validate that `sfunc` names a supported metric. On failure, print the
offending name together with the list of supported metrics, then throw
an `ErrorException` ("Metric keyword error").
"""
function checkfun(sfunc::String)
    # `haskey` is the idiomatic (and O(1)) membership test for Dict keys.
    if !haskey(metric_dict, sfunc)
        println("$sfunc metric is not supported")
        println("metric: ", keys(metric_dict))
        error("Metric keyword error")
    end
end
6565
6666"""
6767 crossvalidate(pl::Machine,X::DataFrame,Y::Vector,sfunc::String="balanced_accuracy_score";nfolds=10,verbose=true)
6868
Runs K-fold cross-validation using balanced accuracy as the default. It supports the
7070following metrics for classification:
7171- "accuracy_score"
7272- "balanced_accuracy_score"
@@ -88,38 +88,38 @@ and the following metrics for regression:
8888- "explained_variance_score"
8989"""
function crossvalidate(pl::Machine, X::DataFrame, Y::Vector,
                       sfunc::String; nfolds=10, verbose::Bool=true)
    # Non-numeric targets are normalized to a plain Vector{String} so the
    # Python metric receives a homogeneous label array.
    labels = eltype(Y) <: Real ? Y : Vector{String}(Y)

    # Resolve the metric name (errors if unsupported) and wrap the Python
    # callable so it yields a Julia Float64.
    checkfun(sfunc)
    scorer = metric_dict[sfunc]
    metric(truth, pred) = PYC.pyconvert(Float64, scorer(truth, pred))
    return crossvalidate(pl, X, labels, metric, nfolds, verbose)
end
103103
# Positional `nfolds` convenience wrapper over the keyword form.
crossvalidate(pl::Machine, X::DataFrame, Y::Vector, sfunc::String, nfolds::Int) =
    crossvalidate(pl, X, Y, sfunc; nfolds)
107107
# Positional `verbose` convenience wrapper over the keyword form.
crossvalidate(pl::Machine, X::DataFrame, Y::Vector, sfunc::String, verbose::Bool) =
    crossvalidate(pl, X, Y, sfunc; verbose)
111111
function crossvalidate(pl::Machine, X::DataFrame, Y::Vector,
                       sfunc::String, nfolds::Int, verbose::Bool)
    # Positional `nfolds` + `verbose` convenience wrapper over the keyword form.
    return crossvalidate(pl, X, Y, sfunc; nfolds, verbose)
end
116116
function crossvalidate(pl::Machine, X::DataFrame, Y::Vector,
                       sfunc::String, averagetype::String; nfolds=10, verbose::Bool=true)
    # Resolve the metric (errors if unsupported), then forward sklearn's
    # `average` keyword (e.g. "macro", "weighted") for multiclass scoring.
    checkfun(sfunc)
    scorer = metric_dict[sfunc]
    metric(truth, pred) =
        PYC.pyconvert(Float64, scorer(truth, pred, average=averagetype))
    return crossvalidate(pl, X, Y, metric, nfolds, verbose)
end
124124
125125
0 commit comments