Skip to content

Commit 08d74a4

Browse files
JakeRaskindvoorhs
andauthored
add_logit_adaptivness (#23)
Co-authored-by: voorhs <[email protected]>
1 parent de779c0 commit 08d74a4

File tree

20 files changed

+660
-134
lines changed

20 files changed

+660
-134
lines changed

autointent/datafiles/default-multilabel-config.yaml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ nodes:
55
search_space:
66
- module_type: vector_db
77
k: [10]
8-
model_name:
8+
embedder_name:
99
- deepvk/USER-bge-m3
1010
- node_type: scoring
1111
metric: scoring_roc_auc
@@ -18,4 +18,5 @@ nodes:
1818
metric: prediction_accuracy
1919
search_space:
2020
- module_type: threshold
21-
thresh: [0.5]
21+
thresh: [0.5]
22+
- module_type: adaptive

autointent/modules/__init__.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from .base import Module
44
from .prediction import (
5+
AdaptivePredictor,
56
ArgmaxPredictor,
67
JinoosPredictor,
78
PredictionModule,
@@ -35,10 +36,13 @@ def create_modules_dict(modules: list[type[T]]) -> dict[str, type[T]]:
3536
[ArgmaxPredictor, JinoosPredictor, ThresholdPredictor, TunablePredictor]
3637
)
3738

38-
PREDICTION_MODULES_MULTILABEL: dict[str, type[Module]] = create_modules_dict([ThresholdPredictor, TunablePredictor])
39+
PREDICTION_MODULES_MULTILABEL: dict[str, type[Module]] = create_modules_dict(
40+
[AdaptivePredictor, ThresholdPredictor, TunablePredictor]
41+
)
3942

4043
__all__ = [
4144
"Module",
45+
"AdaptivePredictor",
4246
"ArgmaxPredictor",
4347
"JinoosPredictor",
4448
"PredictionModule",
Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,15 @@
1+
from .adaptive import AdaptivePredictor
12
from .argmax import ArgmaxPredictor
23
from .base import PredictionModule
34
from .jinoos import JinoosPredictor
45
from .threshold import ThresholdPredictor
56
from .tunable import TunablePredictor
67

7-
__all__ = ["ArgmaxPredictor", "JinoosPredictor", "PredictionModule", "ThresholdPredictor", "TunablePredictor"]
8+
__all__ = [
9+
"AdaptivePredictor",
10+
"ArgmaxPredictor",
11+
"JinoosPredictor",
12+
"PredictionModule",
13+
"ThresholdPredictor",
14+
"TunablePredictor",
15+
]
Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
import json
2+
from pathlib import Path
3+
from typing import Any, TypedDict
4+
5+
import numpy as np
6+
import numpy.typing as npt
7+
from sklearn.metrics import f1_score
8+
from typing_extensions import Self
9+
10+
from autointent import Context
11+
from autointent.context.data_handler import Tag
12+
from autointent.custom_types import LabelType
13+
from autointent.metrics.converter import transform
14+
15+
from .base import PredictionModule
16+
from .utils import InvalidNumClassesError, WrongClassificationError, apply_tags
17+
18+
default_search_space = np.linspace(0, 1, num=10)
19+
20+
21+
class AdaptivePredictorDumpMetadata(TypedDict):
22+
r: float
23+
tags: list[Tag] | None
24+
n_classes: int
25+
26+
27+
class AdaptivePredictor(PredictionModule):
28+
metadata_dict_name = "metadata.json"
29+
n_classes: int
30+
_r: float
31+
tags: list[Tag] | None
32+
name = "adaptive"
33+
34+
def __init__(self, search_space: list[float] | None = None) -> None:
35+
self.search_space = search_space if search_space is not None else default_search_space
36+
37+
@classmethod
38+
def from_context(cls, context: Context, search_space: list[float] | None = None) -> Self:
39+
return cls(
40+
search_space=search_space,
41+
)
42+
43+
def fit(
44+
self,
45+
scores: npt.NDArray[Any],
46+
labels: list[LabelType],
47+
tags: list[Tag] | None = None,
48+
) -> None:
49+
self.tags = tags
50+
multilabel = isinstance(labels[0], list)
51+
if not multilabel:
52+
msg = """AdaptivePredictor is not designed to perform multiclass classification,
53+
consider using other predictor algorithms"""
54+
raise WrongClassificationError(msg)
55+
self.n_classes = (
56+
len(labels[0]) if multilabel and isinstance(labels[0], list) else len(set(labels).difference([-1]))
57+
)
58+
59+
metrics_list = []
60+
for r in self.search_space:
61+
y_pred = multilabel_predict(scores, r, self.tags)
62+
metric_value = multilabel_score(labels, y_pred)
63+
metrics_list.append(metric_value)
64+
65+
self._r = float(self.search_space[np.argmax(metrics_list)])
66+
67+
def predict(self, scores: npt.NDArray[Any]) -> npt.NDArray[Any]:
68+
if scores.shape[1] != self.n_classes:
69+
msg = "Provided scores number don't match with number of classes which predictor was trained on."
70+
raise InvalidNumClassesError(msg)
71+
return multilabel_predict(scores, self._r, self.tags)
72+
73+
def dump(self, path: str) -> None:
74+
dump_dir = Path(path)
75+
76+
metadata = AdaptivePredictorDumpMetadata(r=self._r, tags=self.tags, n_classes=self.n_classes)
77+
78+
with (dump_dir / self.metadata_dict_name).open("w") as file:
79+
json.dump(metadata, file, indent=4)
80+
81+
def load(self, path: str) -> None:
82+
dump_dir = Path(path)
83+
84+
with (dump_dir / self.metadata_dict_name).open() as file:
85+
metadata: AdaptivePredictorDumpMetadata = json.load(file)
86+
87+
self._r = metadata["r"]
88+
self.n_classes = metadata["n_classes"]
89+
self.tags = [Tag(**tag) for tag in metadata["tags"] if metadata["tags"] and isinstance(metadata["tags"], list)] # type: ignore[arg-type, union-attr]
90+
self.metadata = metadata
91+
92+
93+
def get_adapted_threshes(r: float, scores: npt.NDArray[Any]) -> npt.NDArray[Any]:
94+
return r * np.max(scores, axis=1) + (1 - r) * np.min(scores, axis=1) # type: ignore[no-any-return]
95+
96+
97+
def multilabel_predict(scores: npt.NDArray[Any], r: float, tags: list[Tag] | None) -> npt.NDArray[Any]:
98+
thresh = get_adapted_threshes(r, scores)
99+
res = (scores >= thresh[:, None]).astype(int) # suspicious
100+
if tags:
101+
res = apply_tags(res, scores, tags)
102+
return res
103+
104+
105+
def multilabel_score(y_true: list[LabelType], y_pred: npt.NDArray[Any]) -> float:
106+
y_true_, y_pred_ = transform(y_true, y_pred)
107+
108+
return f1_score(y_pred_, y_true_, average="weighted") # type: ignore[no-any-return]

autointent/modules/prediction/argmax.py

Lines changed: 31 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import json
2+
from pathlib import Path
13
from typing import Any
24

35
import numpy as np
@@ -6,14 +8,19 @@
68

79
from autointent import Context
810
from autointent.context.data_handler import Tag
9-
from autointent.custom_types import LabelType
11+
from autointent.custom_types import BaseMetadataDict, LabelType
1012

1113
from .base import PredictionModule
14+
from .utils import InvalidNumClassesError, WrongClassificationError
15+
16+
17+
class ArgmaxPredictorDumpMetadata(BaseMetadataDict):
18+
n_classes: int
1219

1320

1421
class ArgmaxPredictor(PredictionModule):
15-
metadata = {} # noqa: RUF012
1622
name = "argmax"
23+
n_classes: int
1724

1825
def __init__(self) -> None:
1926
pass
@@ -28,13 +35,31 @@ def fit(
2835
labels: list[LabelType],
2936
tags: list[Tag] | None = None,
3037
) -> None:
31-
pass
38+
multilabel = isinstance(labels[0], list)
39+
if multilabel:
40+
msg = "ArgmaxPredictor is compatible with single-label classifiction only"
41+
raise WrongClassificationError(msg)
42+
self.n_classes = len(set(labels).difference([-1]))
3243

3344
def predict(self, scores: npt.NDArray[Any]) -> npt.NDArray[Any]:
45+
if scores.shape[1] != self.n_classes:
46+
msg = "Provided scores number don't match with number of classes which predictor was trained on."
47+
raise InvalidNumClassesError(msg)
3448
return np.argmax(scores, axis=1) # type: ignore[no-any-return]
3549

50+
def dump(self, path: str) -> None:
51+
self.metadata = ArgmaxPredictorDumpMetadata(n_classes=self.n_classes)
52+
53+
dump_dir = Path(path)
54+
55+
with (dump_dir / self.metadata_dict_name).open("w") as file:
56+
json.dump(self.metadata, file, indent=4)
57+
3658
def load(self, path: str) -> None:
37-
pass
59+
dump_dir = Path(path)
3860

39-
def dump(self, path: str) -> None:
40-
pass
61+
with (dump_dir / self.metadata_dict_name).open() as file:
62+
metadata: ArgmaxPredictorDumpMetadata = json.load(file)
63+
64+
self.n_classes = metadata["n_classes"]
65+
self.metadata = metadata

autointent/modules/prediction/base.py

Lines changed: 0 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -58,40 +58,3 @@ def get_prediction_evaluation_data(
5858
return_scores = np.concatenate([scores, oos_scores])
5959

6060
return labels.tolist(), return_scores
61-
62-
63-
def apply_tags(labels: npt.NDArray[Any], scores: npt.NDArray[Any], tags: list[Tag]) -> npt.NDArray[Any]:
64-
"""
65-
this function is intended to be used with multilabel predictor
66-
67-
If some intent classes have common tag (i.e. they are mutually exclusive) \
68-
and were assigned to one sample, leave only that class that has the highest score.
69-
70-
Arguments
71-
---
72-
- `labels`: np.ndarray of shape (n_samples, n_classes) with binary labels
73-
- `scores`: np.ndarray of shape (n_samples, n_classes) with float values from 0..1
74-
- `tags`: list of Tags
75-
76-
Return
77-
---
78-
np.ndarray of shape (n_samples, n_classes) with binary labels
79-
"""
80-
81-
n_samples, _ = labels.shape
82-
res = np.copy(labels)
83-
84-
for i in range(n_samples):
85-
sample_labels = labels[i].astype(bool)
86-
sample_scores = scores[i]
87-
88-
for tag in tags:
89-
if any(sample_labels[idx] for idx in tag.intent_ids):
90-
# Find the index of the class with the highest score among the tagged indices
91-
max_score_index = max(tag.intent_ids, key=lambda idx: sample_scores[idx])
92-
# Set all other tagged indices to 0 in the res
93-
for idx in tag.intent_ids:
94-
if idx != max_score_index:
95-
res[i, idx] = 0
96-
97-
return res

autointent/modules/prediction/jinoos.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,17 +12,20 @@
1212
from autointent.metrics.converter import transform
1313

1414
from .base import PredictionModule
15+
from .utils import InvalidNumClassesError, WrongClassificationError
1516

1617
default_search_space = np.linspace(0, 1, num=100)
1718

1819

1920
class JinoosPredictorDumpMetadata(BaseMetadataDict):
2021
thresh: float
22+
n_classes: int
2123

2224

2325
class JinoosPredictor(PredictionModule):
2426
thresh: float
2527
name = "jinoos"
28+
n_classes: int
2629

2730
def __init__(
2831
self,
@@ -45,6 +48,12 @@ def fit(
4548
"""
4649
TODO: use dev split instead of test split
4750
"""
51+
multilabel = isinstance(labels[0], list)
52+
if multilabel:
53+
msg = "JinoosPredictor is compatible with single-label classification only"
54+
raise WrongClassificationError(msg)
55+
self.n_classes = len(set(labels).difference([-1]))
56+
4857
pred_classes, best_scores = _predict(scores)
4958

5059
metrics_list: list[float] = []
@@ -56,11 +65,14 @@ def fit(
5665
self.thresh = float(self.search_space[np.argmax(metrics_list)])
5766

5867
def predict(self, scores: npt.NDArray[Any]) -> npt.NDArray[Any]:
68+
if scores.shape[1] != self.n_classes:
69+
msg = "Provided scores number don't match with number of classes which predictor was trained on."
70+
raise InvalidNumClassesError(msg)
5971
pred_classes, best_scores = _predict(scores)
6072
return _detect_oos(pred_classes, best_scores, self.thresh)
6173

6274
def dump(self, path: str) -> None:
63-
self.metadata = JinoosPredictorDumpMetadata(thresh=self.thresh)
75+
self.metadata = JinoosPredictorDumpMetadata(thresh=self.thresh, n_classes=self.n_classes)
6476

6577
dump_dir = Path(path)
6678

@@ -75,6 +87,7 @@ def load(self, path: str) -> None:
7587

7688
self.thresh = metadata["thresh"]
7789
self.metadata = metadata
90+
self.n_classes = metadata["n_classes"]
7891

7992

8093
def _predict(scores: npt.NDArray[np.float64]) -> tuple[npt.NDArray[np.int64], npt.NDArray[np.float64]]:

0 commit comments

Comments
 (0)