|
2 | 2 | from decimal import Decimal |
3 | 3 |
|
4 | 4 |
|
class Base:
    """
    Common parent of the score calculators.

    It only remembers which alternative turned out to be true; the scoring
    itself is implemented by the subclasses.
    """

    true_alternative_index: int

    def __init__(self, true_alternative_index: int) -> None:
        # Stored unchanged; no range validation happens at this level.
        self.true_alternative_index = true_alternative_index
| 10 | + |
| 11 | + |
class Brier(Base):
    """
    Scores a prediction when the order of the alternatives is irrelevant.
    """

    def calculate(self, prediction: "Prediction") -> Decimal:
        """
        Return the sum of squared differences between forecast and outcome.

        Every entry of ``prediction.probabilities`` is divided by 100 and
        compared with the outcome: 1 for the alternative at
        ``self.true_alternative_index``, 0 for every other alternative.
        """
        hundred = Decimal(100)
        squared_errors = (
            (
                percentage / hundred
                - Decimal("1.00" if position == self.true_alternative_index else "0.00")
            )
            ** 2
            for position, percentage in enumerate(prediction.probabilities)
        )
        # Start from Decimal("0.00") so the accumulation matches a manual loop.
        return sum(squared_errors, Decimal("0.00"))
| 26 | + |
| 27 | + |
class OrderedCategorical(Base):
    """
    Calculates scores when the order of predictions matters.

    The alternatives are cut at every boundary between two neighbours, turning
    one prediction into several two-way predictions (everything up to the cut
    versus everything after it). Each two-way prediction is scored with
    ``Brier`` and the scores are averaged.
    """

    def calculate(self, prediction: "Prediction") -> Decimal:
        """
        Return the mean Brier score over all cumulative splits of ``prediction``.
        """
        pair_count = self._pair_count(prediction.probabilities)
        total = Decimal("0.00")
        for index in range(pair_count):
            pair = Prediction(
                self._split_probabilities(index, prediction.probabilities),
                # Placeholder only: _score_pair decides which side of the
                # split counts as true and ignores this value.
                true_alternative_index=1,
            )
            total += self._score_pair(index, pair)
        return self._average(total, pair_count)

    @staticmethod
    def _pair_count(probabilities: typing.Tuple[Decimal, ...]) -> int:
        """
        Return the number of splits, which is one fewer than the number of
        alternatives. For example, with three alternatives A, B and C the
        pairs are:

        - A and BC
        - AB and C
        """
        return len(probabilities) - 1

    @staticmethod
    def _average(total: Decimal, count: int) -> Decimal:
        """
        Return ``total`` divided by ``count``.
        """
        return total / Decimal(count)

    def _score_pair(self, index: int, pair: "Prediction") -> Decimal:
        """
        Return the Brier score of the two-way prediction produced by the split
        at ``index``.

        The first group of that split holds the alternatives ``0..index``, so
        the true alternative lies in the first group (position 0) exactly when
        ``index >= self.true_alternative_index``; otherwise it lies in the
        second group (position 1).
        """
        assert len(pair.probabilities) == 2, "There must be exactly two probabilities."
        # ">=" rather than ">": the cut directly after the true alternative
        # still leaves the true alternative in the first group.
        true_alternative_index = 0 if index >= self.true_alternative_index else 1
        brier_calculator = Brier(true_alternative_index=true_alternative_index)
        return brier_calculator.calculate(pair)

    @staticmethod
    def _split_probabilities(
        index: int, probabilities: typing.Tuple[Decimal, ...]
    ) -> typing.Tuple[Decimal, Decimal]:
        """
        Given an index and a tuple of at least two probabilities, return a pair
        of grouped probabilities: the sum of the entries up to and including
        ``index`` and the sum of the entries after it.

        With exactly two probabilities the only split (index 0) returns the
        two values unchanged.
        """
        assert (
            len(probabilities) >= 2
        ), "At least two probabilities are needed to form a pair."
        cut = index + 1
        sum_first_part = Decimal(sum(probabilities[:cut]))
        sum_second_part = Decimal(sum(probabilities[cut:]))
        return sum_first_part, sum_second_part
| 78 | + |
| 79 | + |
class Prediction:
    """
    This class encapsulates probabilities for a given question, together with
    the index of the alternative that turned out to be true and a flag saying
    whether the order of the alternatives matters for scoring.
    """

    # Filled in by the first read of ``brier_score``; None until then.
    _cached_brier_score: typing.Optional[Decimal] = None
    order_matters: bool
    probabilities: typing.Tuple[Decimal, ...]
    true_alternative_index: int

    def __init__(
        self,
        probabilities: typing.Tuple[Decimal, ...],
        true_alternative_index: int,
        order_matters: bool = False,
    ) -> None:
        """
        2 or more probabilities are required. Make sure that they sum to 100.
        The true alternative index must point at one of the probabilities.
        """
        # The checks run in this order: sum, length, lower bound, upper bound.
        assert sum(probabilities) == 100, "Probabilities need to sum to 100."
        alternative_count = len(probabilities)
        assert alternative_count >= 2, "A prediction needs at least two probabilities."
        assert (
            true_alternative_index >= 0
        ), "The true alternative index cannot be negative"
        assert (
            true_alternative_index <= alternative_count - 1
        ), "Probabilities need to contain the true alternative"
        self.probabilities = probabilities
        self.true_alternative_index = true_alternative_index
        self.order_matters = order_matters

    @property
    def brier_score(self) -> Decimal:
        """
        Return the score of this prediction, computing it at most once.

        ``OrderedCategorical`` is used when ``order_matters`` is set, ``Brier``
        otherwise. The result is kept on the instance and returned directly on
        later reads.
        """
        cached = self._cached_brier_score
        if isinstance(cached, Decimal):
            return cached
        calculator: typing.Union[Brier, OrderedCategorical]
        if self.order_matters:
            calculator = OrderedCategorical(
                true_alternative_index=self.true_alternative_index
            )
        else:
            calculator = Brier(true_alternative_index=self.true_alternative_index)
        score = calculator.calculate(self)
        self._cached_brier_score = score
        return score
0 commit comments