|
20 | 20 | import pytorch_lightning as pl |
21 | 21 | import torch # pylint: disable=unused-import |
22 | 22 | from anomalib.models.components import AnomalyModule |
| 23 | +from pydantic import BaseModel |
23 | 24 | from pytorch_lightning import Callback |
24 | 25 | from pytorch_lightning.utilities.types import STEP_OUTPUT |
25 | 26 |
|
26 | 27 | # https://github.com/python/cpython/issues/90015#issuecomment-1172996118 |
27 | 28 | MapOrValue: TypeAlias = "float | torch.Tensor | np.ndarray" |
28 | 29 |
|
29 | 30 |
|
30 | | -def normalize_anomaly_score(raw_score: MapOrValue, threshold: float) -> MapOrValue: |
31 | | - """Normalize anomaly score value or map based on threshold. |
| 31 | +class EvalThreshold(BaseModel): |
| 32 | + """Pair of raw and normalized threshold values used for consistency enforcement. |
| 33 | +
|
| 34 | + Attributes: |
| 35 | + raw: The unnormalized threshold. |
| 36 | + normalized: The corresponding normalized threshold. |
| 37 | + """ |
| 38 | + |
| 39 | + raw: float |
| 40 | + normalized: float |
| 41 | + |
| 42 | + |
| 43 | +def ensure_scores_consistency( |
| 44 | + normalized_score: MapOrValue, |
| 45 | + raw_score: MapOrValue, |
| 46 | + eval_threshold: EvalThreshold, |
| 47 | +) -> MapOrValue: |
| 48 | + """Enforce that the classification based on normalized scores matches the raw classification. |
| 49 | +
|
| 50 | + For every sample, if `raw_score >= eval_threshold.raw` (anomaly), the normalized score is |
| 51 | + clipped to be at least `eval_threshold.normalized`. If `raw_score < eval_threshold.raw` |
| 52 | + (normal), the normalized score is clipped to be strictly below `eval_threshold.normalized`: |
| 53 | + it is capped at the smaller of `normalized - 1e-3` and the next representable value below it. |
32 | 54 |
|
33 | 55 | Args: |
34 | | - raw_score: Raw anomaly score valure or map |
35 | | - threshold: Threshold for anomaly detection |
| 56 | + normalized_score: Normalized anomaly score value or map to adjust. |
| 57 | + raw_score: Original (unnormalized) anomaly score used to determine the ground-truth |
| 58 | + classification for each sample. |
| 59 | + eval_threshold: Threshold pair defining the decision boundary in both spaces. |
36 | 60 |
|
37 | 61 | Returns: |
38 | | - Normalized anomaly score value or map clipped between 0 and 1000 |
| 62 | + Normalized score with consistent predictions. |
39 | 63 | """ |
40 | | - if threshold > 0: |
41 | | - normalized_score = (raw_score / threshold) * 100.0 |
42 | | - elif threshold == 0: |
43 | | - # TODO: Is this the best way to handle this case? |
44 | | - normalized_score = (raw_score + 1) * 100.0 |
45 | | - else: |
46 | | - normalized_score = 200.0 - ((raw_score / threshold) * 100.0) |
47 | | - |
48 | | - # Ensures that the normalized scores are consistent with the raw scores |
49 | | - # For all the items whose prediction changes after normalization, force the normalized score to be |
50 | | - # consistent with the prediction made on the raw score by clipping the score: |
51 | | - # - to 100.0 if the prediction was "anomaly" on the raw score and "good" on the normalized score |
52 | | - # - to 99.99 if the prediction was "good" on the raw score and "anomaly" on the normalized score |
53 | 64 | score = raw_score |
54 | 65 | if isinstance(score, torch.Tensor): |
55 | 66 | score = score.cpu().numpy() |
56 | | - # Anomalib classify as anomaly if anomaly_score gte threshold |
57 | | - is_anomaly_mask = score >= threshold |
| 67 | + |
| 68 | + boundary = eval_threshold.normalized |
| 69 | + is_anomaly_mask = score >= eval_threshold.raw |
58 | 70 | is_not_anomaly_mask = np.bitwise_not(is_anomaly_mask) |
| 71 | + |
| 72 | + _inf: torch.Tensor | np.ndarray |
| 73 | + below_boundary: torch.Tensor | np.ndarray |
| 74 | + anomaly_boundary: torch.Tensor | np.ndarray |
| 75 | + epsilon = 1e-3 |
59 | 76 | if isinstance(normalized_score, torch.Tensor): |
| 77 | + device = normalized_score.device |
| 78 | + # Work in the scores' dtype: cast the boundaries to the same dtype so that the casts take effect |
| 79 | + _inf = torch.tensor(float("inf"), dtype=normalized_score.dtype, device=device) |
| 80 | + boundary_tensor = torch.tensor(boundary, dtype=normalized_score.dtype, device=device) |
| 81 | + anomaly_boundary = boundary_tensor.clone() |
| 82 | + # If dtype cast causes anomaly_boundary to be smaller than normalized boundary (float), |
| 83 | + # increase it up to the next representable value |
| 84 | + if float(anomaly_boundary) < boundary: |
| 85 | + anomaly_boundary = torch.nextafter(anomaly_boundary, _inf) |
| 86 | + # Ensure consistency after rounding to 3 decimal places |
| 87 | + below_boundary = torch.min(torch.nextafter(boundary_tensor, -_inf), boundary_tensor - epsilon) |
| 88 | + |
60 | 89 | if normalized_score.dim() == 0: |
61 | 90 | normalized_score = ( |
62 | | - normalized_score.clamp(min=100.0) if is_anomaly_mask else normalized_score.clamp(max=99.99) |
| 91 | + normalized_score.clamp(min=anomaly_boundary) |
| 92 | + if is_anomaly_mask |
| 93 | + else normalized_score.clamp(max=below_boundary) |
63 | 94 | ) |
64 | 95 | else: |
65 | | - normalized_score[is_anomaly_mask] = normalized_score[is_anomaly_mask].clamp(min=100.0) |
66 | | - normalized_score[is_not_anomaly_mask] = normalized_score[is_not_anomaly_mask].clamp(max=99.99) |
| 96 | + normalized_score[is_anomaly_mask] = normalized_score[is_anomaly_mask].clamp(min=anomaly_boundary) |
| 97 | + normalized_score[is_not_anomaly_mask] = normalized_score[is_not_anomaly_mask].clamp(max=below_boundary) |
67 | 98 | elif isinstance(normalized_score, np.ndarray) or np.isscalar(normalized_score): |
| 99 | + # Work in the scores' dtype: cast the boundaries to the same dtype so that the casts take effect |
| 100 | + dtype = normalized_score.dtype if isinstance(normalized_score, np.ndarray) else np.float64 |
| 101 | + _inf = np.array(np.inf, dtype=dtype) |
| 102 | + boundary_array = np.array(boundary, dtype=dtype) |
| 103 | + anomaly_boundary = boundary_array.copy() |
| 104 | + # If dtype cast causes anomaly_boundary to be smaller than normalized boundary (float), |
| 105 | + # increase it up to the next representable value |
| 106 | + if float(anomaly_boundary) < boundary: |
| 107 | + anomaly_boundary = np.nextafter(anomaly_boundary, _inf) |
| 108 | + # Ensure consistency after rounding to 3 decimal places |
| 109 | + below_boundary = np.minimum(np.nextafter(boundary_array, -_inf), boundary_array - epsilon) |
| 110 | + |
68 | 111 | if np.isscalar(normalized_score) or normalized_score.ndim == 0: # type: ignore[union-attr] |
69 | 112 | normalized_score = ( |
70 | | - np.clip(normalized_score, a_min=100.0, a_max=None) |
| 113 | + np.clip(normalized_score, a_min=anomaly_boundary, a_max=None) |
71 | 114 | if is_anomaly_mask |
72 | | - else np.clip(normalized_score, a_min=None, a_max=99.99) |
| 115 | + else np.clip(normalized_score, a_min=None, a_max=below_boundary) |
73 | 116 | ) |
74 | 117 | else: |
75 | 118 | normalized_score = cast(np.ndarray, normalized_score) |
76 | | - normalized_score[is_anomaly_mask] = np.clip(normalized_score[is_anomaly_mask], a_min=100.0, a_max=None) |
| 119 | + normalized_score[is_anomaly_mask] = np.clip( |
| 120 | + normalized_score[is_anomaly_mask], a_min=anomaly_boundary, a_max=None |
| 121 | + ) |
77 | 122 | normalized_score[is_not_anomaly_mask] = np.clip( |
78 | | - normalized_score[is_not_anomaly_mask], a_min=None, a_max=99.99 |
| 123 | + normalized_score[is_not_anomaly_mask], a_min=None, a_max=below_boundary |
79 | 124 | ) |
80 | 125 |
|
| 126 | + return normalized_score |
| 127 | + |
| 128 | + |
| 129 | +def normalize_anomaly_score( |
| 130 | + raw_score: MapOrValue, |
| 131 | + threshold: float, |
| 132 | + eval_threshold: EvalThreshold | None = None, |
| 133 | +) -> MapOrValue: |
| 134 | + """Normalize anomaly score value or map based on threshold. |
| 135 | +
|
| 136 | + The training threshold maps to 100.0 in normalized space. After the linear scaling, |
| 137 | + `ensure_scores_consistency` is called to guarantee that every sample's normalized |
| 138 | + classification matches its raw classification. |
| 139 | +
|
| 140 | + Args: |
| 141 | + raw_score: Raw anomaly score value or map. |
| 142 | + threshold: Threshold for anomaly detection, usually it is the training threshold. |
| 143 | + eval_threshold: Threshold used during evaluation. It is used to ensure consistency between raw scores |
| 144 | + and normalized scores. When `None`, an `EvalThreshold` with `raw=threshold` and `normalized=100.0` is used, |
| 145 | + which reproduces the original behaviour for the training-threshold case. |
| 146 | +
|
| 147 | + Returns: |
| 148 | + Normalized anomaly score value or map clipped between 0 and 1000 |
| 149 | + """ |
| 150 | + if threshold > 0: |
| 151 | + normalized_score = (raw_score / threshold) * 100.0 |
| 152 | + elif threshold == 0: |
| 153 | + # TODO: Is this the best way to handle this case? |
| 154 | + normalized_score = (raw_score + 1) * 100.0 |
| 155 | + else: |
| 156 | + normalized_score = 200.0 - ((raw_score / threshold) * 100.0) |
| 157 | + |
| 158 | + _eval_threshold = eval_threshold if eval_threshold is not None else EvalThreshold(raw=threshold, normalized=100.0) |
| 159 | + normalized_score = ensure_scores_consistency(normalized_score, raw_score, _eval_threshold) |
| 160 | + |
81 | 161 | if isinstance(normalized_score, torch.Tensor): |
82 | 162 | return torch.clamp(normalized_score, 0.0, 1000.0) |
83 | 163 |
|
|
0 commit comments