Skip to content

Commit 5257b76

Browse files
committed
style: reformat
1 parent d1856ac commit 5257b76

File tree

16 files changed

+111
-109
lines changed

16 files changed

+111
-109
lines changed

cyeva/core/base.py

Lines changed: 2 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
calc_threshold_mae,
2727
calc_multiclass_accuracy_ratio,
2828
calc_multiclass_hanssen_kuipers_score,
29-
calc_multiclass_heidke_skill_score
29+
calc_multiclass_heidke_skill_score,
3030
)
3131

3232

@@ -36,7 +36,6 @@ class Comparison:
3636
def __init__(
3737
self, observation: Union[np.ndarray, list], forecast: Union[np.ndarray, list]
3838
):
39-
4039
if isinstance(observation, Quantity):
4140
observation = observation.magnitude
4241
if isinstance(forecast, Quantity):
@@ -216,7 +215,6 @@ def calc_multiclass_accuracy_ratio(
216215
*args,
217216
**kwargs
218217
) -> float:
219-
220218
if observation is None:
221219
observation = self.observation
222220
if forecast is None:
@@ -232,14 +230,13 @@ def calc_multiclass_hanssen_kuipers_score(
232230
*args,
233231
**kwargs
234232
) -> float:
235-
236233
if observation is None:
237234
observation = self.observation
238235
if forecast is None:
239236
forecast = self.forecast
240237

241238
return calc_multiclass_hanssen_kuipers_score(observation, forecast)
242-
239+
243240
@result_round_digit(4)
244241
def calc_multiclass_heidke_skill_score(
245242
self,
@@ -248,7 +245,6 @@ def calc_multiclass_heidke_skill_score(
248245
*args,
249246
**kwargs
250247
) -> float:
251-
252248
if observation is None:
253249
observation = self.observation
254250
if forecast is None:
@@ -295,7 +291,6 @@ def calc_threshold_accuracy_ratio(
295291
*args,
296292
**kwargs
297293
) -> float:
298-
299294
if observation is None:
300295
observation = self.observation
301296
if forecast is None:
@@ -316,7 +311,6 @@ def calc_threshold_hit_ratio(
316311
*args,
317312
**kwargs
318313
) -> float:
319-
320314
if observation is None:
321315
observation = self.observation
322316
if forecast is None:
@@ -337,7 +331,6 @@ def calc_threshold_miss_ratio(
337331
*args,
338332
**kwargs
339333
) -> float:
340-
341334
if observation is None:
342335
observation = self.observation
343336
if forecast is None:
@@ -358,7 +351,6 @@ def calc_threshold_false_alarm_ratio(
358351
*args,
359352
**kwargs
360353
) -> float:
361-
362354
if observation is None:
363355
observation = self.observation
364356
if forecast is None:
@@ -379,7 +371,6 @@ def calc_threshold_bias_score(
379371
*args,
380372
**kwargs
381373
) -> float:
382-
383374
if observation is None:
384375
observation = self.observation
385376
if forecast is None:
@@ -400,7 +391,6 @@ def calc_threshold_ts(
400391
*args,
401392
**kwargs
402393
) -> float:
403-
404394
if observation is None:
405395
observation = self.observation
406396
if forecast is None:
@@ -421,7 +411,6 @@ def calc_threshold_mae(
421411
*args,
422412
**kwargs
423413
) -> float:
424-
425414
if observation is None:
426415
observation = self.observation
427416
if forecast is None:

cyeva/core/precip.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -209,8 +209,12 @@ def __init__(
209209
self.kind = kind
210210
self.unit = unit
211211
self.lev = lev
212-
self.observation = (self.observation * UNITS.parse_expression(unit)).to("mm").magnitude
213-
self.forecast = (self.forecast * UNITS.parse_expression(unit)).to("mm").magnitude
212+
self.observation = (
213+
(self.observation * UNITS.parse_expression(unit)).to("mm").magnitude
214+
)
215+
self.forecast = (
216+
(self.forecast * UNITS.parse_expression(unit)).to("mm").magnitude
217+
)
214218
self.df = pd.DataFrame(
215219
{
216220
"observation": self.observation,

cyeva/core/statistic.py

Lines changed: 45 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ def calc_binary_quadrant_values(
4848

4949
return hits, misses, false_alarms, correct_rejects, total
5050

51+
5152
@assert_length
5253
@drop_nan
5354
def calc_multiclass_confusion_matrix(
@@ -80,11 +81,17 @@ class K n(F_K,O_1) n(F_K,O_2) n(F_K,O_K)
8081
cates = np.unique(np.concatenate([np.unique(observation), np.unique(forecast)]))
8182
confusion_matrix_list = []
8283
for obs_cate_, fcst_cate_ in product(cates, cates):
83-
count_cate_ = Counter((observation==obs_cate_) & (forecast==fcst_cate_))[True]
84+
count_cate_ = Counter((observation == obs_cate_) & (forecast == fcst_cate_))[
85+
True
86+
]
8487
confusion_matrix_list.append([obs_cate_, fcst_cate_, count_cate_])
8588

86-
confusion_matrix = pd.DataFrame(np.array(confusion_matrix_list), columns=['observation','forecast', 'count'])
87-
confusion_matrix = confusion_matrix.pivot_table('count', index='forecast',columns='observation',aggfunc='sum').astype(int)
89+
confusion_matrix = pd.DataFrame(
90+
np.array(confusion_matrix_list), columns=["observation", "forecast", "count"]
91+
)
92+
confusion_matrix = confusion_matrix.pivot_table(
93+
"count", index="forecast", columns="observation", aggfunc="sum"
94+
).astype(int)
8895

8996
assert len(observation) == np.sum(confusion_matrix.values)
9097

@@ -104,98 +111,101 @@ def calc_multiclass_accuracy_ratio(
104111
Args:
105112
observation (Union[list, np.ndarray]): Multiclass observation data array
106113
that consist of class labels.
107-
forecast (Union[list, np.ndarray]): Multiclass forecast data array
114+
forecast (Union[list, np.ndarray]): Multiclass forecast data array
108115
that consist of class labels.
109116
110117
Returns:
111118
float: The accuracy(%) of multiclass forecast. Perfect score 100.
112119
"""
113120

114-
confusion_matrix = calc_multiclass_confusion_matrix(
115-
observation, forecast
116-
)
121+
confusion_matrix = calc_multiclass_confusion_matrix(observation, forecast)
117122
# compute the sum of hits of all categories
118123
all_hits = np.sum(confusion_matrix.values.diagonal())
119124
total = len(observation)
120125

121-
return (all_hits / total) * 100
126+
return (all_hits / total) * 100
127+
122128

123129
@assert_length
124130
@fix_zero_division
125131
@drop_nan
126132
def calc_multiclass_heidke_skill_score(
127133
observation: Union[list, np.ndarray], forecast: Union[list, np.ndarray]
128134
) -> float:
129-
"""calculate the Heidke Skill Score (HSS), which measures the
130-
fraction of correct forecasts after eliminating those forecasts
131-
which would be correct due purely to random chance.
135+
"""calculate the Heidke Skill Score (HSS), which measures the
136+
fraction of correct forecasts after eliminating those forecasts
137+
which would be correct due purely to random chance.
132138
133-
HSS = \frac {\frac {1} {Total} \sum\limits_{i=1}^{K} n(F_i,O_i) -
134-
\frac {1} {Total^2} \sum\limits_{i=1}^{K} N(F_i)N(O_i) }
139+
HSS = \frac {\frac {1} {Total} \sum\limits_{i=1}^{K} n(F_i,O_i) -
140+
\frac {1} {Total^2} \sum\limits_{i=1}^{K} N(F_i)N(O_i) }
135141
{1 - \frac {1} {Total^2} \sum\limits_{i=1}^{K} N(F_i)*N(O_i)}
136142
137143
Args:
138144
observation (Union[list, np.ndarray]): Multiclass observation data array
139145
that consist of class labels.
140-
forecast (Union[list, np.ndarray]): Multiclass forecast data array
146+
forecast (Union[list, np.ndarray]): Multiclass forecast data array
141147
that consist of class labels.
142148
143149
Returns:
144150
float: HSS score. Perfect score 1.
145151
"""
146152

147-
confusion_matrix = calc_multiclass_confusion_matrix(
148-
observation, forecast
149-
)
153+
confusion_matrix = calc_multiclass_confusion_matrix(observation, forecast)
150154
total = len(observation)
151155

152156
# compute HSS score
153-
acc_ = np.sum(confusion_matrix.values.diagonal()) / total
154-
reference_acc_ = np.sum(confusion_matrix.sum(axis=0).values * confusion_matrix.sum(axis=1).values) / (total**2)
157+
acc_ = np.sum(confusion_matrix.values.diagonal()) / total
158+
reference_acc_ = np.sum(
159+
confusion_matrix.sum(axis=0).values * confusion_matrix.sum(axis=1).values
160+
) / (total**2)
155161
perfect_acc_ = 1
156-
hss_score_ = ( acc_ - reference_acc_ ) / (perfect_acc_ - reference_acc_)
157-
162+
hss_score_ = (acc_ - reference_acc_) / (perfect_acc_ - reference_acc_)
163+
158164
return hss_score_
159165

166+
160167
@assert_length
161168
@fix_zero_division
162169
@drop_nan
163170
def calc_multiclass_hanssen_kuipers_score(
164171
observation: Union[list, np.ndarray], forecast: Union[list, np.ndarray]
165172
) -> float:
166-
"""calculate the Hanssen and Kuipers Score (HSS), which is
167-
similar to the Heidke skill score (above), except that in
168-
the denominator the fraction of correct forecasts due to
173+
"""calculate the Hanssen and Kuipers Score (HSS), which is
174+
similar to the Heidke skill score (above), except that in
175+
the denominator the fraction of correct forecasts due to
169176
random chance is for an unbiased forecast.
170177
171-
HK = \frac {\frac {1} {Total} \sum\limits_{i=1}^{K} n(F_i,O_i) -
172-
\frac {1} {Total^2} \sum\limits_{i=1}^{K} N(F_i)N(O_i) }
178+
HK = \frac {\frac {1} {Total} \sum\limits_{i=1}^{K} n(F_i,O_i) -
179+
\frac {1} {Total^2} \sum\limits_{i=1}^{K} N(F_i)N(O_i) }
173180
{1 - \frac {1} {Total^2} \sum\limits_{i=1}^{K} N(O_i)^2}
174181
175182
Args:
176183
observation (Union[list, np.ndarray]): Multiclass observation data array
177184
that consist of class labels.
178-
forecast (Union[list, np.ndarray]): Multiclass forecast data array
185+
forecast (Union[list, np.ndarray]): Multiclass forecast data array
179186
that consist of class labels.
180187
181188
Returns:
182189
float: HK score. Perfect score 1.
183190
"""
184191

185-
confusion_matrix = calc_multiclass_confusion_matrix(
186-
observation, forecast
187-
)
192+
confusion_matrix = calc_multiclass_confusion_matrix(observation, forecast)
188193
total = len(observation)
189194

190195
# compute HK score
191-
acc_ = np.sum(confusion_matrix.values.diagonal()) / total
192-
reference_acc_ = np.sum(confusion_matrix.sum(axis=0).values * confusion_matrix.sum(axis=1).values) / (total**2)
196+
acc_ = np.sum(confusion_matrix.values.diagonal()) / total
197+
reference_acc_ = np.sum(
198+
confusion_matrix.sum(axis=0).values * confusion_matrix.sum(axis=1).values
199+
) / (total**2)
193200
perfect_acc_ = 1
194-
unbias_reference_acc_ = np.sum(confusion_matrix.sum(axis=0).values**2) / (total**2)
195-
hk_score_ = ( acc_ - reference_acc_ ) / (perfect_acc_ - unbias_reference_acc_)
196-
201+
unbias_reference_acc_ = np.sum(confusion_matrix.sum(axis=0).values ** 2) / (
202+
total**2
203+
)
204+
hk_score_ = (acc_ - reference_acc_) / (perfect_acc_ - unbias_reference_acc_)
205+
197206
return hk_score_
198207

208+
199209
@assert_length
200210
@fix_zero_division
201211
@drop_nan

cyeva/core/temp.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,12 @@ def __init__(
2121
super().__init__(observation, forecast)
2222
self.kind = kind
2323
self.lev = lev
24-
self.observation = (self.observation * UNITS.parse_expression(unit)).to("degC").magnitude
25-
self.forecast = (self.forecast * UNITS.parse_expression(unit)).to("degC").magnitude
24+
self.observation = (
25+
(self.observation * UNITS.parse_expression(unit)).to("degC").magnitude
26+
)
27+
self.forecast = (
28+
(self.forecast * UNITS.parse_expression(unit)).to("degC").magnitude
29+
)
2630
self.df = pd.DataFrame(
2731
{
2832
"observation": self.observation,

cyeva/core/weather_code.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,6 @@ def gather_all_factors(self):
4848
}
4949
)
5050

51-
5251
df = pd.DataFrame(result)
5352

5453
return df

cyeva/utils/decorators.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,6 @@ def convert_to_ndarray(func):
3535

3636
@wraps(func)
3737
def wrapper(observation, forecast, *args, **kwargs):
38-
3938
if not isinstance(observation, np.ndarray) and not isinstance(
4039
observation, Number
4140
):

docs/source/conf.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,4 +61,4 @@
6161
# so a file named "default.css" will overwrite the builtin "default.css".
6262
html_static_path = ["_static"]
6363

64-
master_doc = 'index'
64+
master_doc = "index"

setup.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,9 @@ def get_version(rel_path):
3838
url="https://github.com/caiyunapp/cyeva",
3939
include_package_data=True,
4040
package_data={"": ["*.csv", "*.config", "*.nl", "*.json"]},
41-
packages=setuptools.find_packages(exclude=["*.tests", "*.tests.*", "tests.*", "tests"]),
41+
packages=setuptools.find_packages(
42+
exclude=["*.tests", "*.tests.*", "tests.*", "tests"]
43+
),
4244
install_requires=required,
4345
classifiers=[
4446
"Development Status :: 4 - Beta",
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
from .accuracy_ratio import ACCURACY_RATE_CASE
22
from .hk import HK_CASE
3-
from .hss import HSS_CASE
3+
from .hss import HSS_CASE
Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
ACCURACY_RATE_CASE = [
22
{"obs": [1, 2, 3, 4, 5], "fct": [1, 2, 3, 4, 5], "result": 100},
3-
{"obs": ['A', 'B', 'C', 'D', 'E'], "fct": ['A', 'B', 'C', 'D', 'E'], "result": 100},
4-
{"obs": [1]*5 + [2]*5 + [3]*5 + [4]*5 + [5]*5, "fct": [1, 2, 3, 4, 5]*5, "result": 20}
5-
6-
]
3+
{"obs": ["A", "B", "C", "D", "E"], "fct": ["A", "B", "C", "D", "E"], "result": 100},
4+
{
5+
"obs": [1] * 5 + [2] * 5 + [3] * 5 + [4] * 5 + [5] * 5,
6+
"fct": [1, 2, 3, 4, 5] * 5,
7+
"result": 20,
8+
},
9+
]
Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,8 @@
11
HK_CASE = [
22
{"obs": [1, 2, 3, 4, 5], "fct": [1, 2, 3, 4, 5], "result": 1},
3-
{"obs": [1]*5 + [2]*5 + [3]*5 + [4]*5 + [5]*5, "fct": [1, 2, 3, 4, 5]*5, "result": 0}
4-
]
3+
{
4+
"obs": [1] * 5 + [2] * 5 + [3] * 5 + [4] * 5 + [5] * 5,
5+
"fct": [1, 2, 3, 4, 5] * 5,
6+
"result": 0,
7+
},
8+
]
Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,8 @@
11
HSS_CASE = [
22
{"obs": [1, 2, 3, 4, 5], "fct": [1, 2, 3, 4, 5], "result": 1},
3-
{"obs": [1]*5 + [2]*5 + [3]*5 + [4]*5 + [5]*5, "fct": [1, 2, 3, 4, 5]*5, "result": 0}
4-
]
3+
{
4+
"obs": [1] * 5 + [2] * 5 + [3] * 5 + [4] * 5 + [5] * 5,
5+
"fct": [1, 2, 3, 4, 5] * 5,
6+
"result": 0,
7+
},
8+
]

tests/functions/test_base.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414

1515

1616
def test_comparison():
17-
1817
for case in RMSE_CASE:
1918
obs = case["obs"]
2019
fcst = case["fct"]

0 commit comments

Comments
 (0)