
Commit 4ebb88b

Merge pull request #14 from AstraZeneca/tests-restructure
Tests restructured for scorecards and added non-grouped aggregations
2 parents 2a0f13e + 69f3e17 commit 4ebb88b

File tree

8 files changed: +30 -42 lines changed


rexmex/dataset.py

Lines changed: 1 addition & 3 deletions
@@ -13,9 +13,7 @@ class DatasetReader(object):
     """

     def __init__(self):
-        self.base_url = (
-            "https://raw.githubusercontent.com/AstraZeneca/rexmex/main/dataset/"
-        )
+        self.base_url = "https://raw.githubusercontent.com/AstraZeneca/rexmex/main/dataset/"

     def read_dataset(self, dataset: str = "erdos_renyi_example"):
         """

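The constructor change above is only a re-wrap of the base_url assignment. For orientation, a minimal usage sketch of the reader follows; it assumes read_dataset fetches the named example from base_url and returns a pandas DataFrame, which this diff does not show:

from rexmex.dataset import DatasetReader

reader = DatasetReader()
# Assumed behaviour: downloads the example dataset from base_url and returns
# it as a pandas DataFrame of scores.
scores = reader.read_dataset("erdos_renyi_example")
print(scores.head())
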
rexmex/metrics/classification.py

Lines changed: 1 addition & 3 deletions
@@ -563,8 +563,6 @@ def pr_auc_score(y_true: np.array, y_score: np.array) -> float:
     Returns:
         pr_auc (float): The value of the precision-recall area under the curve.
     """
-    precision, recall, thresholds = sklearn.metrics.precision_recall_curve(
-        y_true, y_score
-    )
+    precision, recall, thresholds = sklearn.metrics.precision_recall_curve(y_true, y_score)
     pr_auc = sklearn.metrics.auc(recall, precision)
     return pr_auc
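
The body of pr_auc_score is only collapsed onto one line; it still feeds scikit-learn's precision-recall curve into sklearn.metrics.auc. A small sanity check against calling scikit-learn directly (the label and score vectors are made up for illustration):

import numpy as np
import sklearn.metrics

from rexmex.metrics.classification import pr_auc_score

y_true = np.array([0, 0, 1, 1])
y_score = np.array([0.1, 0.4, 0.35, 0.8])

# The same two scikit-learn calls the function performs internally.
precision, recall, _ = sklearn.metrics.precision_recall_curve(y_true, y_score)
expected = sklearn.metrics.auc(recall, precision)

assert np.isclose(pr_auc_score(y_true, y_score), expected)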

rexmex/metrics/rating.py

Lines changed: 2 additions & 6 deletions
@@ -87,9 +87,7 @@ def root_mean_squared_error(y_true: np.array, y_score: np.array) -> float:
     return rmse


-def symmetric_mean_absolute_percentage_error(
-    y_true: np.array, y_score: np.array
-) -> float:
+def symmetric_mean_absolute_percentage_error(y_true: np.array, y_score: np.array) -> float:
     """
     Calculate the symmetric mean absolute percentage error (SMAPE) for a ground-truth prediction vector pair.

@@ -99,7 +97,5 @@ def symmetric_mean_absolute_percentage_error(
     Returns:
         smape (float): The value of the symmetric mean absolute percentage error.
     """
-    smape = 100 * np.mean(
-        np.abs(y_score - y_true) / ((np.abs(y_score) + np.abs(y_true)) / 2)
-    )
+    smape = 100 * np.mean(np.abs(y_score - y_true) / ((np.abs(y_score) + np.abs(y_true)) / 2))
     return smape
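
The SMAPE implementation is likewise unchanged apart from line wrapping: it computes 100 * mean(|y_score - y_true| / ((|y_score| + |y_true|) / 2)). A tiny worked example with made-up vectors:

import numpy as np

from rexmex.metrics.rating import symmetric_mean_absolute_percentage_error

y_true = np.array([1.0, 2.0])
y_score = np.array([1.0, 3.0])

# Per-element terms: |1 - 1| / ((1 + 1) / 2) = 0.0 and |3 - 2| / ((3 + 2) / 2) = 0.4,
# so SMAPE = 100 * mean([0.0, 0.4]) = 20.0.
assert np.isclose(symmetric_mean_absolute_percentage_error(y_true, y_score), 20.0)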

rexmex/metricset.py

Lines changed: 2 additions & 5 deletions
@@ -1,5 +1,4 @@
-import pandas as pd
-from typing import List, Dict, Tuple
+from typing import List, Tuple

 from rexmex.utils import binarize, normalize

@@ -114,9 +113,7 @@ def __init__(self):
         self["pr_auc"] = pr_auc_score
         self["average_precision"] = average_precision_score
         self["f1_score"] = binarize(f1_score)
-        self["matthews_correlation_coefficent"] = binarize(
-            matthews_correlation_coefficient
-        )
+        self["matthews_correlation_coefficent"] = binarize(matthews_correlation_coefficient)
         self["fowlkes_mallows_index"] = binarize(fowlkes_mallows_index)
         self["precision"] = binarize(precision_score)
         self["recall"] = binarize(recall_score)
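
Dropping the unused pandas import and re-wrapping the Matthews correlation entry does not change how a metric set behaves: it remains a dict-like mapping from metric name to callable. A short sketch of the pattern the unit test below relies on (assuming filter_metrics mutates the set in place, as its use in tests/unit/test_scorecard.py suggests):

from rexmex.metricset import ClassificationMetricSet

metric_set = ClassificationMetricSet()
# Keep only two of the registered metrics.
metric_set.filter_metrics(["roc_auc", "pr_auc"])
print(sorted(metric_set.keys()))  # expected: ['pr_auc', 'roc_auc']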

rexmex/scorecard.py

Lines changed: 7 additions & 17 deletions
@@ -10,11 +10,9 @@ class ScoreCard(object):
     """

     def __init__(self, metric_set: rexmex.metricset.MetricSet):
-        self._metric_set = metric_set
+        self.metric_set = metric_set

-    def _get_performance_metrics(
-        self, y_true: np.array, y_score: np.array
-    ) -> pd.DataFrame:
+    def _get_performance_metrics(self, y_true: np.array, y_score: np.array) -> pd.DataFrame:
         """
         A method to get the performance metrics for a pair of vectors.

@@ -24,15 +22,11 @@
         Returns:
             performance_metrics (pd.DataFrame): The performance metrics calculated from the vectors.
         """
-        performance_metrics = {
-            name: [metric(y_true, y_score)] for name, metric in self._metric_set.items()
-        }
+        performance_metrics = {name: [metric(y_true, y_score)] for name, metric in self.metric_set.items()}
         performance_metrics = pd.DataFrame.from_dict(performance_metrics)
         return performance_metrics

-    def generate_report(
-        self, scores_to_evaluate: pd.DataFrame, groupping: List[str] = None
-    ) -> pd.DataFrame:
+    def generate_report(self, scores_to_evaluate: pd.DataFrame, groupping: List[str] = None) -> pd.DataFrame:
         """
         A method to calculate (aggregated) performance metrics based
         on a dataframe of ground truth and predictions. It assumes that the dataframe has the `y_true`
@@ -47,13 +41,9 @@ def generate_report(
         """
         if groupping is not None:
             scores_to_evaluate = scores_to_evaluate.groupby(groupping)
-            report = scores_to_evaluate.apply(
-                lambda group: self._get_performance_metrics(group.y_true, group.y_score)
-            )
+            report = scores_to_evaluate.apply(lambda group: self._get_performance_metrics(group.y_true, group.y_score))
         else:
-            report = self._get_performance_metrics(
-                scores_to_evaluate.y_true, scores_to_evaluate.y_score
-            )
+            report = self._get_performance_metrics(scores_to_evaluate.y_true, scores_to_evaluate.y_score)
         return report

     def __repr__(self):
@@ -66,4 +56,4 @@ def print_metrics(self):
         """
         Printing the name of metrics.
         """
-        print({k for k in self._metric_set.keys()})
+        print({k for k in self.metric_set.keys()})
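
Renaming self._metric_set to self.metric_set is what lets the unit test below reach score_card.metric_set directly, and the integration tests now exercise generate_report both with and without groupping. A usage sketch mirroring those tests (the toy values are invented; only the y_true, y_score and source_group column names come from this PR):

import pandas as pd

from rexmex.metricset import ClassificationMetricSet
from rexmex.scorecard import ScoreCard

score_card = ScoreCard(ClassificationMetricSet())

scores = pd.DataFrame(
    {
        "source_group": [0, 0, 0, 0, 1, 1, 1, 1],
        "y_true": [0, 1, 0, 1, 0, 1, 1, 0],
        "y_score": [0.2, 0.9, 0.4, 0.7, 0.1, 0.8, 0.6, 0.3],
    }
)

# Ungrouped: one row holding every metric in the set.
ungrouped = score_card.generate_report(scores)
# Grouped: one row of metrics per source_group value.
grouped = score_card.generate_report(scores, groupping=["source_group"])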

tests/integration/test_aggregation.py

Lines changed: 9 additions & 0 deletions
@@ -15,6 +15,9 @@ def test_classification(self):
         metric_set = ClassificationMetricSet()
         score_card = ScoreCard(metric_set)

+        performance_metrics = score_card.generate_report(self.scores)
+        assert performance_metrics.shape == (1, 11)
+
         performance_metrics = score_card.generate_report(
             self.scores, groupping=["source_group"]
         )
@@ -30,6 +33,9 @@ def test_regression(self):
         metric_set.normalize_metrics()
         score_card = ScoreCard(metric_set)

+        performance_metrics = score_card.generate_report(self.scores)
+        assert performance_metrics.shape == (1, 7)
+
         performance_metrics = score_card.generate_report(
             self.scores, groupping=["source_group"]
         )
@@ -44,6 +50,9 @@ def test_addition(self):
         metric_set = RatingMetricSet() + ClassificationMetricSet()
         score_card = ScoreCard(metric_set)

+        performance_metrics = score_card.generate_report(self.scores)
+        assert performance_metrics.shape == (1, 18)
+
         performance_metrics = score_card.generate_report(
             self.scores, groupping=["source_group"]
         )
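
The new ungrouped assertions pin the report width to the number of metrics in the set: 11 for ClassificationMetricSet, 7 for the normalized set in test_regression, and 18 for RatingMetricSet + ClassificationMetricSet. If those counts drift as metrics are added, a less brittle variant would be the following (a hypothetical alternative, not part of this PR, assuming the dict-like metric set supports len()):

performance_metrics = score_card.generate_report(self.scores)
assert performance_metrics.shape == (1, len(metric_set))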

tests/unit/test_scorecard.py

Lines changed: 7 additions & 7 deletions
@@ -8,18 +8,18 @@


 class TestMetricSet(unittest.TestCase):
+    def setUp(self):
+        self.metric_set = ClassificationMetricSet()
+        self.score_card = ScoreCard(self.metric_set)
+
     def test_representation(self):
-        metric_set = ClassificationMetricSet()
-        score_card = ScoreCard(metric_set)
-        assert repr(score_card) == "ScoreCard()"
+        assert repr(self.score_card) == "ScoreCard()"

     def test_printing(self):
-        metric_set = ClassificationMetricSet()
-        metric_set.filter_metrics(["roc_auc", "pr_auc"])
-        score_card = ScoreCard(metric_set)
+        self.score_card.metric_set.filter_metrics(["roc_auc", "pr_auc"])
         captured = StringIO()
         sys.stdout = captured
-        score_card.print_metrics()
+        self.score_card.print_metrics()
         sys.stdout = sys.__stdout__
         out = captured.getvalue().strip("\n")
         assert out == str({"roc_auc", "pr_auc"})

tox.ini

Lines changed: 1 addition & 1 deletion
@@ -21,4 +21,4 @@ deps =
     flake8
     flake8-black
 commands =
-    flake8 --select BLK100 rexmex/ tests/ setup.py
+    flake8 --select BLK120 rexmex/ tests/ setup.py
