Merge pull request #14 from AstraZeneca/tests-restructure
Tests restructured for scorecards and non-grouped aggregations added.
benedekrozemberczki authored Dec 2, 2021
2 parents 2a0f13e + 69f3e17 commit 4ebb88b
Showing 8 changed files with 30 additions and 42 deletions.
4 changes: 1 addition & 3 deletions rexmex/dataset.py
@@ -13,9 +13,7 @@ class DatasetReader(object):
     """

     def __init__(self):
-        self.base_url = (
-            "https://raw.githubusercontent.com/AstraZeneca/rexmex/main/dataset/"
-        )
+        self.base_url = "https://raw.githubusercontent.com/AstraZeneca/rexmex/main/dataset/"

     def read_dataset(self, dataset: str = "erdos_renyi_example"):
         """
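For orientation, a minimal usage sketch of the reader whose formatting changed above — assuming read_dataset returns a pandas DataFrame, which the visible hunk does not show:

    from rexmex.dataset import DatasetReader

    reader = DatasetReader()
    # The default dataset name comes from the signature above; the file is fetched
    # from base_url (the exact file naming on the server is not shown here).
    scores = reader.read_dataset(dataset="erdos_renyi_example")
    print(scores.head())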
4 changes: 1 addition & 3 deletions rexmex/metrics/classification.py
@@ -563,8 +563,6 @@ def pr_auc_score(y_true: np.array, y_score: np.array) -> float:
     Returns:
         pr_auc (float): The value of the precision-recall area under the curve.
     """
-    precision, recall, thresholds = sklearn.metrics.precision_recall_curve(
-        y_true, y_score
-    )
+    precision, recall, thresholds = sklearn.metrics.precision_recall_curve(y_true, y_score)
    pr_auc = sklearn.metrics.auc(recall, precision)
    return pr_auc
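A self-contained sketch of what the reformatted line computes, using the same sklearn calls that appear in the hunk (the toy vectors are invented for illustration):

    import numpy as np
    import sklearn.metrics

    y_true = np.array([0, 0, 1, 1])
    y_score = np.array([0.1, 0.4, 0.35, 0.8])

    # precision_recall_curve evaluates precision/recall at each candidate threshold;
    # auc then integrates precision over recall, which is what pr_auc_score returns.
    precision, recall, thresholds = sklearn.metrics.precision_recall_curve(y_true, y_score)
    pr_auc = sklearn.metrics.auc(recall, precision)
    print(pr_auc)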
8 changes: 2 additions & 6 deletions rexmex/metrics/rating.py
@@ -87,9 +87,7 @@ def root_mean_squared_error(y_true: np.array, y_score: np.array) -> float:
     return rmse


-def symmetric_mean_absolute_percentage_error(
-    y_true: np.array, y_score: np.array
-) -> float:
+def symmetric_mean_absolute_percentage_error(y_true: np.array, y_score: np.array) -> float:
     """
     Calculate the symmetric mean absolute percentage error (SMAPE) for a ground-truth prediction vector pair.
@@ -99,7 +97,5 @@ def symmetric_mean_absolute_percentage_error(
     Returns:
         smape (float): The value of the symmetric mean absolute percentage error.
     """
-    smape = 100 * np.mean(
-        np.abs(y_score - y_true) / ((np.abs(y_score) + np.abs(y_true)) / 2)
-    )
+    smape = 100 * np.mean(np.abs(y_score - y_true) / ((np.abs(y_score) + np.abs(y_true)) / 2))
     return smape
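A small worked example of the SMAPE expression above: each element contributes |score − truth| divided by the mean of |score| and |truth|, and the average is reported as a percentage (toy vectors invented for illustration):

    import numpy as np

    y_true = np.array([2.0, 4.0, 6.0])
    y_score = np.array([3.0, 4.0, 3.0])

    # per-element terms: 1/2.5 = 0.40, 0/4.0 = 0.00, 3/4.5 ≈ 0.67
    smape = 100 * np.mean(np.abs(y_score - y_true) / ((np.abs(y_score) + np.abs(y_true)) / 2))
    print(round(smape, 1))  # 35.6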
7 changes: 2 additions & 5 deletions rexmex/metricset.py
@@ -1,5 +1,4 @@
 import pandas as pd
-from typing import List, Dict, Tuple
+from typing import List, Tuple

 from rexmex.utils import binarize, normalize

@@ -114,9 +113,7 @@ def __init__(self):
         self["pr_auc"] = pr_auc_score
         self["average_precision"] = average_precision_score
         self["f1_score"] = binarize(f1_score)
-        self["matthews_correlation_coefficent"] = binarize(
-            matthews_correlation_coefficient
-        )
+        self["matthews_correlation_coefficent"] = binarize(matthews_correlation_coefficient)
         self["fowlkes_mallows_index"] = binarize(fowlkes_mallows_index)
         self["precision"] = binarize(precision_score)
         self["recall"] = binarize(recall_score)
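The hunk above shows the metric set as a dict-like mapping from metric names to callables, with binarize wrapping metrics that need hard labels rather than raw scores. A rough illustrative sketch of that pattern — not rexmex's code; the 0.5 cut-off and the wrapper body are assumptions:

    import numpy as np
    from sklearn.metrics import f1_score, roc_auc_score

    def binarize(metric, threshold=0.5):
        # Assumed behaviour: threshold continuous scores so label-based metrics apply.
        def wrapped(y_true, y_score):
            return metric(y_true, (np.asarray(y_score) >= threshold).astype(int))
        return wrapped

    metrics = {
        "roc_auc": roc_auc_score,        # consumes raw scores directly
        "f1_score": binarize(f1_score),  # needs thresholded predictions
    }
    y_true = np.array([0, 1, 1, 0])
    y_score = np.array([0.2, 0.9, 0.4, 0.6])
    print({name: round(metric(y_true, y_score), 3) for name, metric in metrics.items()})
    # {'roc_auc': 0.75, 'f1_score': 0.5} with this toy data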
24 changes: 7 additions & 17 deletions rexmex/scorecard.py
@@ -10,11 +10,9 @@ class ScoreCard(object):
     """

     def __init__(self, metric_set: rexmex.metricset.MetricSet):
-        self._metric_set = metric_set
+        self.metric_set = metric_set

-    def _get_performance_metrics(
-        self, y_true: np.array, y_score: np.array
-    ) -> pd.DataFrame:
+    def _get_performance_metrics(self, y_true: np.array, y_score: np.array) -> pd.DataFrame:
         """
         A method to get the performance metrics for a pair of vectors.
@@ -24,15 +22,11 @@ def _get_performance_metrics(
         Returns:
             performance_metrics (pd.DataFrame): The performance metrics calculated from the vectors.
         """
-        performance_metrics = {
-            name: [metric(y_true, y_score)] for name, metric in self._metric_set.items()
-        }
+        performance_metrics = {name: [metric(y_true, y_score)] for name, metric in self.metric_set.items()}
         performance_metrics = pd.DataFrame.from_dict(performance_metrics)
         return performance_metrics

-    def generate_report(
-        self, scores_to_evaluate: pd.DataFrame, groupping: List[str] = None
-    ) -> pd.DataFrame:
+    def generate_report(self, scores_to_evaluate: pd.DataFrame, groupping: List[str] = None) -> pd.DataFrame:
         """
         A method to calculate (aggregated) performance metrics based
         on a dataframe of ground truth and predictions. It assumes that the dataframe has the `y_true`
@@ -47,13 +41,9 @@
         """
         if groupping is not None:
             scores_to_evaluate = scores_to_evaluate.groupby(groupping)
-            report = scores_to_evaluate.apply(
-                lambda group: self._get_performance_metrics(group.y_true, group.y_score)
-            )
+            report = scores_to_evaluate.apply(lambda group: self._get_performance_metrics(group.y_true, group.y_score))
         else:
-            report = self._get_performance_metrics(
-                scores_to_evaluate.y_true, scores_to_evaluate.y_score
-            )
+            report = self._get_performance_metrics(scores_to_evaluate.y_true, scores_to_evaluate.y_score)
         return report

     def __repr__(self):
@@ -66,4 +56,4 @@ def print_metrics(self):
         """
         Printing the name of metrics.
         """
-        print({k for k in self._metric_set.keys()})
+        print({k for k in self.metric_set.keys()})
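Putting the renamed metric_set attribute and the new non-grouped branch together, a usage sketch — the y_true/y_score column names come from the docstring above, source_group mirrors the grouping column used in the integration tests, the import paths are assumed from the file layout in this diff, and the data is invented:

    import pandas as pd
    from rexmex.metricset import ClassificationMetricSet
    from rexmex.scorecard import ScoreCard

    scores = pd.DataFrame({
        "source_group": ["a", "a", "b", "b"],
        "y_true": [0, 1, 1, 0],
        "y_score": [0.1, 0.8, 0.6, 0.3],
    })

    score_card = ScoreCard(ClassificationMetricSet())
    overall = score_card.generate_report(scores)                                 # single aggregate row
    per_group = score_card.generate_report(scores, groupping=["source_group"])   # one row per group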
9 changes: 9 additions & 0 deletions tests/integration/test_aggregation.py
@@ -15,6 +15,9 @@ def test_classification(self):
         metric_set = ClassificationMetricSet()
         score_card = ScoreCard(metric_set)

+        performance_metrics = score_card.generate_report(self.scores)
+        assert performance_metrics.shape == (1, 11)
+
         performance_metrics = score_card.generate_report(
             self.scores, groupping=["source_group"]
         )
@@ -30,6 +33,9 @@ def test_regression(self):
         metric_set.normalize_metrics()
         score_card = ScoreCard(metric_set)

+        performance_metrics = score_card.generate_report(self.scores)
+        assert performance_metrics.shape == (1, 7)
+
         performance_metrics = score_card.generate_report(
             self.scores, groupping=["source_group"]
         )
@@ -44,6 +50,9 @@ def test_addition(self):
         metric_set = RatingMetricSet() + ClassificationMetricSet()
         score_card = ScoreCard(metric_set)

+        performance_metrics = score_card.generate_report(self.scores)
+        assert performance_metrics.shape == (1, 18)
+
         performance_metrics = score_card.generate_report(
             self.scores, groupping=["source_group"]
         )
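The new assertions pin down that non-grouped branch: without groupping, generate_report returns a single-row DataFrame with one column per metric — 11 for the classification set, 7 for the rating set, and 11 + 7 = 18 when the two sets are added. A minimal restatement, reusing the invented scores frame from the ScoreCard sketch above (RatingMetricSet is assumed importable alongside ClassificationMetricSet):

    combined = RatingMetricSet() + ClassificationMetricSet()
    report = ScoreCard(combined).generate_report(scores)  # no groupping -> one aggregate row
    assert report.shape == (1, 18)                        # 7 rating + 11 classification metrics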
14 changes: 7 additions & 7 deletions tests/unit/test_scorecard.py
@@ -8,18 +8,18 @@


 class TestMetricSet(unittest.TestCase):
+    def setUp(self):
+        self.metric_set = ClassificationMetricSet()
+        self.score_card = ScoreCard(self.metric_set)
+
     def test_representation(self):
-        metric_set = ClassificationMetricSet()
-        score_card = ScoreCard(metric_set)
-        assert repr(score_card) == "ScoreCard()"
+        assert repr(self.score_card) == "ScoreCard()"

     def test_printing(self):
-        metric_set = ClassificationMetricSet()
-        metric_set.filter_metrics(["roc_auc", "pr_auc"])
-        score_card = ScoreCard(metric_set)
+        self.score_card.metric_set.filter_metrics(["roc_auc", "pr_auc"])
         captured = StringIO()
         sys.stdout = captured
-        score_card.print_metrics()
+        self.score_card.print_metrics()
         sys.stdout = sys.__stdout__
         out = captured.getvalue().strip("\n")
         assert out == str({"roc_auc", "pr_auc"})
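The rename from _metric_set to metric_set makes the attribute part of the public surface these tests now rely on; outside the test harness the same pattern reads roughly as follows (set printing order may vary):

    score_card = ScoreCard(ClassificationMetricSet())
    score_card.metric_set.filter_metrics(["roc_auc", "pr_auc"])  # keep only two metrics
    score_card.print_metrics()  # prints something like {'roc_auc', 'pr_auc'}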
2 changes: 1 addition & 1 deletion tox.ini
@@ -21,4 +21,4 @@ deps =
     flake8
     flake8-black
 commands =
-    flake8 --select BLK100 rexmex/ tests/ setup.py
+    flake8 --select BLK120 rexmex/ tests/ setup.py
