Skip to content

Commit 5d1e4ab

Browse files
committed
lsp reformat
1 parent 5b0f594 commit 5d1e4ab

File tree

4 files changed

+351
-351
lines changed

4 files changed

+351
-351
lines changed

src/AutoMLPipeline.jl

Lines changed: 21 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,12 @@ using AMLPipelineBase: NARemovers
1212

1313
export Machine, Learner, Transformer, Workflow, Computer
1414
export holdout, kfold, score, infer_eltype, nested_dict_to_tuples,
15-
nested_dict_set!, nested_dict_merge, create_transformer,
16-
mergedict, getiris, getprofb,
17-
skipmean, skipmedian, skipstd,
18-
aggregatorclskipmissing,
19-
find_catnum_columns,
20-
train_test_split
15+
nested_dict_set!, nested_dict_merge,
16+
mergedict, getiris, getprofb,
17+
skipmean, skipmedian, skipstd,
18+
aggregatorclskipmissing,
19+
find_catnum_columns,
20+
train_test_split
2121

2222

2323
export Baseline, Identity
@@ -50,24 +50,24 @@ export crossvalidate
5050
export skoperator
5151

5252
function skoperator(name::String; args...)::Machine
53-
sklr = keys(SKLearners.learner_dict)
54-
skpr = keys(SKPreprocessors.preprocessor_dict)
55-
if name ∈ sklr
56-
obj = SKLearner(name; args...)
57-
elseif name ∈ skpr
58-
obj = SKPreprocessor(name; args...)
59-
else
60-
skoperator()
61-
throw(ArgumentError("$name does not exist"))
62-
end
63-
return obj
53+
sklr = keys(SKLearners.learner_dict)
54+
skpr = keys(SKPreprocessors.preprocessor_dict)
55+
if name ∈ sklr
56+
obj = SKLearner(name; args...)
57+
elseif name ∈ skpr
58+
obj = SKPreprocessor(name; args...)
59+
else
60+
skoperator()
61+
throw(ArgumentError("$name does not exist"))
62+
end
63+
return obj
6464
end
6565

6666
function skoperator()
67-
sklr = keys(SKLearners.learner_dict)
68-
skpr = keys(SKPreprocessors.preprocessor_dict)
69-
println("Please choose among these pipeline elements:")
70-
println([sklr..., skpr...])
67+
sklr = keys(SKLearners.learner_dict)
68+
skpr = keys(SKPreprocessors.preprocessor_dict)
69+
println("Please choose among these pipeline elements:")
70+
println([sklr..., skpr...])
7171
end
7272

7373
end # module

src/skcrossvalidator.jl

Lines changed: 61 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -16,57 +16,57 @@ const metric_dict = Dict{String,PYC.Py}()
1616
const SKM = PYC.pynew()
1717

1818
function __init__()
19-
PYC.pycopy!(SKM, PYC.pyimport("sklearn.metrics"))
20-
21-
metric_dict["roc_auc_score"] = SKM.roc_auc_score
22-
metric_dict["accuracy_score"] = SKM.accuracy_score
23-
metric_dict["auc"] = SKM.auc
24-
metric_dict["average_precision_score"] = SKM.average_precision_score
25-
metric_dict["balanced_accuracy_score"] = SKM.balanced_accuracy_score
26-
metric_dict["brier_score_loss"] = SKM.brier_score_loss
27-
metric_dict["classification_report"] = SKM.classification_report
28-
metric_dict["cohen_kappa_score"] = SKM.cohen_kappa_score
29-
metric_dict["confusion_matrix"] = SKM.confusion_matrix
30-
metric_dict["f1_score"] = SKM.f1_score
31-
metric_dict["fbeta_score"] = SKM.fbeta_score
32-
metric_dict["hamming_loss"] = SKM.hamming_loss
33-
metric_dict["hinge_loss"] = SKM.hinge_loss
34-
metric_dict["log_loss"] = SKM.log_loss
35-
metric_dict["matthews_corrcoef"] = SKM.matthews_corrcoef
36-
metric_dict["multilabel_confusion_matrix"] = SKM.multilabel_confusion_matrix
37-
metric_dict["precision_recall_curve"] = SKM.precision_recall_curve
38-
metric_dict["precision_recall_fscore_support"] = SKM.precision_recall_fscore_support
39-
metric_dict["precision_score"] = SKM.precision_score
40-
metric_dict["recall_score"] = SKM.recall_score
41-
metric_dict["roc_auc_score"] = SKM.roc_auc_score
42-
metric_dict["roc_curve"] = SKM.roc_curve
43-
metric_dict["jaccard_score"] = SKM.jaccard_score
44-
metric_dict["zero_one_loss"] = SKM.zero_one_loss
45-
# regression
46-
metric_dict["mean_squared_error"] = SKM.mean_squared_error
47-
metric_dict["mean_squared_log_error"] = SKM.mean_squared_log_error
48-
metric_dict["mean_absolute_error"] = SKM.mean_absolute_error
49-
metric_dict["median_absolute_error"] = SKM.median_absolute_error
50-
metric_dict["r2_score"] = SKM.r2_score
51-
metric_dict["max_error"] = SKM.max_error
52-
metric_dict["mean_poisson_deviance"] = SKM.mean_poisson_deviance
53-
metric_dict["mean_gamma_deviance"] = SKM.mean_gamma_deviance
54-
metric_dict["mean_tweedie_deviance"] = SKM.mean_tweedie_deviance
55-
metric_dict["explained_variance_score"] = SKM.explained_variance_score
19+
PYC.pycopy!(SKM, PYC.pyimport("sklearn.metrics"))
20+
21+
metric_dict["roc_auc_score"] = SKM.roc_auc_score
22+
metric_dict["accuracy_score"] = SKM.accuracy_score
23+
metric_dict["auc"] = SKM.auc
24+
metric_dict["average_precision_score"] = SKM.average_precision_score
25+
metric_dict["balanced_accuracy_score"] = SKM.balanced_accuracy_score
26+
metric_dict["brier_score_loss"] = SKM.brier_score_loss
27+
metric_dict["classification_report"] = SKM.classification_report
28+
metric_dict["cohen_kappa_score"] = SKM.cohen_kappa_score
29+
metric_dict["confusion_matrix"] = SKM.confusion_matrix
30+
metric_dict["f1_score"] = SKM.f1_score
31+
metric_dict["fbeta_score"] = SKM.fbeta_score
32+
metric_dict["hamming_loss"] = SKM.hamming_loss
33+
metric_dict["hinge_loss"] = SKM.hinge_loss
34+
metric_dict["log_loss"] = SKM.log_loss
35+
metric_dict["matthews_corrcoef"] = SKM.matthews_corrcoef
36+
metric_dict["multilabel_confusion_matrix"] = SKM.multilabel_confusion_matrix
37+
metric_dict["precision_recall_curve"] = SKM.precision_recall_curve
38+
metric_dict["precision_recall_fscore_support"] = SKM.precision_recall_fscore_support
39+
metric_dict["precision_score"] = SKM.precision_score
40+
metric_dict["recall_score"] = SKM.recall_score
41+
metric_dict["roc_auc_score"] = SKM.roc_auc_score
42+
metric_dict["roc_curve"] = SKM.roc_curve
43+
metric_dict["jaccard_score"] = SKM.jaccard_score
44+
metric_dict["zero_one_loss"] = SKM.zero_one_loss
45+
# regression
46+
metric_dict["mean_squared_error"] = SKM.mean_squared_error
47+
metric_dict["mean_squared_log_error"] = SKM.mean_squared_log_error
48+
metric_dict["mean_absolute_error"] = SKM.mean_absolute_error
49+
metric_dict["median_absolute_error"] = SKM.median_absolute_error
50+
metric_dict["r2_score"] = SKM.r2_score
51+
metric_dict["max_error"] = SKM.max_error
52+
metric_dict["mean_poisson_deviance"] = SKM.mean_poisson_deviance
53+
metric_dict["mean_gamma_deviance"] = SKM.mean_gamma_deviance
54+
metric_dict["mean_tweedie_deviance"] = SKM.mean_tweedie_deviance
55+
metric_dict["explained_variance_score"] = SKM.explained_variance_score
5656
end
5757

5858
function checkfun(sfunc::String)
59-
if !(sfunc in keys(metric_dict))
60-
println("$sfunc metric is not supported")
61-
println("metric: ", keys(metric_dict))
62-
error("Metric keyword error")
63-
end
59+
if !(sfunc in keys(metric_dict))
60+
println("$sfunc metric is not supported")
61+
println("metric: ", keys(metric_dict))
62+
error("Metric keyword error")
63+
end
6464
end
6565

6666
"""
6767
crossvalidate(pl::Machine,X::DataFrame,Y::Vector,sfunc::String="balanced_accuracy_score";nfolds=10,verbose=true)
6868
69-
Runs K-fold cross-validation using balanced accuracy as the default. It supports the
69+
Runs K-fold cross-validation using balanced accuracy as the default. It supports the
7070
following metrics for classification:
7171
- "accuracy_score"
7272
- "balanced_accuracy_score"
@@ -88,38 +88,38 @@ and the following metrics for regression:
8888
- "explained_variance_score"
8989
"""
9090
function crossvalidate(pl::Machine, X::DataFrame, Y::Vector,
91-
sfunc::String; nfolds=10, verbose::Bool=true)
91+
sfunc::String; nfolds=10, verbose::Bool=true)
9292

93-
YC = Y
94-
if !(eltype(YC) <: Real)
95-
YC = Y |> Vector{String}
96-
end
93+
YC = Y
94+
if !(eltype(YC) <: Real)
95+
YC = Y |> Vector{String}
96+
end
9797

98-
checkfun(sfunc)
99-
pfunc = metric_dict[sfunc]
100-
metric(a, b) = pfunc(a, b) |> (x -> PYC.pyconvert(Float64, x))
101-
crossvalidate(pl, X, YC, metric, nfolds, verbose)
98+
checkfun(sfunc)
99+
pfunc = metric_dict[sfunc]
100+
metric(a, b) = pfunc(a, b) |> (x -> PYC.pyconvert(Float64, x))
101+
crossvalidate(pl, X, YC, metric, nfolds, verbose)
102102
end
103103

104104
function crossvalidate(pl::Machine, X::DataFrame, Y::Vector, sfunc::String, nfolds::Int)
105-
crossvalidate(pl, X, Y, sfunc; nfolds)
105+
crossvalidate(pl, X, Y, sfunc; nfolds)
106106
end
107107

108108
function crossvalidate(pl::Machine, X::DataFrame, Y::Vector, sfunc::String, verbose::Bool)
109-
crossvalidate(pl, X, Y, sfunc; verbose)
109+
crossvalidate(pl, X, Y, sfunc; verbose)
110110
end
111111

112112
function crossvalidate(pl::Machine, X::DataFrame, Y::Vector,
113-
sfunc::String, nfolds::Int, verbose::Bool)
114-
crossvalidate(pl, X, Y, sfunc; nfolds, verbose)
113+
sfunc::String, nfolds::Int, verbose::Bool)
114+
crossvalidate(pl, X, Y, sfunc; nfolds, verbose)
115115
end
116116

117117
function crossvalidate(pl::Machine, X::DataFrame, Y::Vector,
118-
sfunc::String, averagetype::String; nfolds=10, verbose::Bool=true)
119-
checkfun(sfunc)
120-
pfunc = metric_dict[sfunc]
121-
metric(a, b) = pfunc(a, b, average=averagetype) |> (x -> PYC.pyconvert(Float64, x))
122-
crossvalidate(pl, X, Y, metric, nfolds, verbose)
118+
sfunc::String, averagetype::String; nfolds=10, verbose::Bool=true)
119+
checkfun(sfunc)
120+
pfunc = metric_dict[sfunc]
121+
metric(a, b) = pfunc(a, b, average=averagetype) |> (x -> PYC.pyconvert(Float64, x))
122+
crossvalidate(pl, X, Y, metric, nfolds, verbose)
123123
end
124124

125125

0 commit comments

Comments
 (0)