Commit 66fecba

Merge branch 'dev' of https://github.com/deeppavlov/AutoIntent into my-fix-branch
2 parents: c23860f + 13d8c40

41 files changed: +1480 additions, -199 deletions

.gitignore

Lines changed: 2 additions & 0 deletions
@@ -182,3 +182,5 @@ vector_db*
 *.db
 *.sqlite
 /wandb
+model_output/
+my.py

.vscode/settings.json

Lines changed: 6 additions & 1 deletion
@@ -8,5 +8,10 @@
             "*.yaml",
             "!*/.github/*/*.yaml"
         ]
-    }
+    },
+    "python.testing.pytestArgs": [
+        "."
+    ],
+    "python.testing.unittestEnabled": false,
+    "python.testing.pytestEnabled": true
 }

README.md

Lines changed: 16 additions & 1 deletion
@@ -7,7 +7,7 @@ Auto ML for intent classification.

 Documentation: [deeppavlov.github.io/AutoIntent](https://deeppavlov.github.io/AutoIntent/).

-> The project is under active development.
+The project is under active development.

 ## Installation

@@ -35,6 +35,21 @@ pipeline.fit(dataset)
 pipeline.predict(["show me my latest transactions"])
 ```

+## Cite
+
+If you find our work useful, please cite our EMNLP 2025 [paper](https://arxiv.org/abs/2509.21138):
+```
+@misc{alekseev2025autointentautomltextclassification,
+      title={AutoIntent: AutoML for Text Classification},
+      author={Ilya Alekseev and Roman Solomatin and Darina Rustamova and Denis Kuznetsov},
+      year={2025},
+      eprint={2509.21138},
+      archivePrefix={arXiv},
+      primaryClass={cs.CL},
+      url={https://arxiv.org/abs/2509.21138},
+}
+```
+
 ## Disclaimer

 This project is in development phase. Bugs and breaking changes are expected. Contributions and feedback are welcome! See [CONTRIBUTING.md](./CONTRIBUTING.md).

docs/optimizer_config.schema.json

Lines changed: 1 addition & 8 deletions
@@ -266,12 +266,6 @@
         "description": "Whether to use embeddings caching.",
         "title": "Use Cache",
         "type": "boolean"
-      },
-      "freeze": {
-        "default": true,
-        "description": "Whether to freeze the model parameters.",
-        "title": "Freeze",
-        "type": "boolean"
       }
     },
     "title": "EmbedderConfig",
@@ -578,8 +572,7 @@
       "query_prompt": null,
       "passage_prompt": null,
       "similarity_fn_name": "cosine",
-      "use_cache": true,
-      "freeze": true
+      "use_cache": true
     }
   },
   "cross_encoder_config": {

src/autointent/_dump_tools/main.py

Lines changed: 3 additions & 0 deletions
@@ -4,6 +4,7 @@

 import numpy as np
 import numpy.typing as npt
+import torch

 from autointent.configs import CrossEncoderConfig, EmbedderConfig
 from autointent.context.optimization_info import Artifact
@@ -108,6 +109,8 @@ def dump(
                 simple_attrs[key] = val
             elif isinstance(val, np.ndarray):
                 arrays[key] = val
+            elif isinstance(val, torch.Tensor):
+                arrays[key] = val.cpu().numpy()
             else:
                 # Use the appropriate dumper for complex objects
                 Dumper._dump_single_object(key, val, path, exists_ok, raise_errors)
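
The new branch lets the dumper persist `torch.Tensor` attributes the same way as NumPy arrays by first moving them to CPU. A standalone sketch of that conversion (the helper name is hypothetical; only the `val.cpu().numpy()` call comes from the diff):

```python
import numpy as np
import torch


def to_numpy(val: torch.Tensor | np.ndarray) -> np.ndarray:
    """Normalize a value to a NumPy array so it can be stored with the other arrays."""
    if isinstance(val, torch.Tensor):
        # .cpu() handles CUDA/MPS tensors; .numpy() then exposes the data as a NumPy array
        return val.cpu().numpy()
    return val


arrays: dict[str, np.ndarray] = {
    "weights": to_numpy(torch.randn(3, 4)),
    "scores": to_numpy(np.ones(5)),
}
print({key: value.shape for key, value in arrays.items()})  # {'weights': (3, 4), 'scores': (5,)}
```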

src/autointent/_dump_tools/unit_dumpers.py

Lines changed: 6 additions & 6 deletions
@@ -21,7 +21,7 @@
 )

 from autointent import Embedder, Ranker, VectorIndex
-from autointent._wrappers import BaseTorchModuleWithVocab
+from autointent._wrappers import BaseTorchModule
 from autointent.schemas import TagsList

 from .base import BaseObjectDumper, ModuleSimpleAttributes
@@ -276,11 +276,11 @@ def check_isinstance(cls, obj: Any) -> bool:  # noqa: ANN401
         return isinstance(obj, PreTrainedTokenizer | PreTrainedTokenizerFast)


-class TorchModelDumper(BaseObjectDumper[BaseTorchModuleWithVocab]):
+class TorchModelDumper(BaseObjectDumper[BaseTorchModule]):
     dir_or_file_name = "torch_models"

     @staticmethod
-    def dump(obj: BaseTorchModuleWithVocab, path: Path, exists_ok: bool) -> None:
+    def dump(obj: BaseTorchModule, path: Path, exists_ok: bool) -> None:
         path.mkdir(parents=True, exist_ok=exists_ok)
         class_info = {
             "module": obj.__class__.__module__,
@@ -291,16 +291,16 @@ def dump(obj: BaseTorchModuleWithVocab, path: Path, exists_ok: bool) -> None:
         obj.dump(path)

     @staticmethod
-    def load(path: Path, **kwargs: Any) -> BaseTorchModuleWithVocab:  # noqa: ANN401, ARG004
+    def load(path: Path, **kwargs: Any) -> BaseTorchModule:  # noqa: ANN401, ARG004
         with (path / "class_info.json").open("r") as f:
             class_info = json.load(f)
         module = importlib.import_module(class_info["module"])
-        model_class: BaseTorchModuleWithVocab = getattr(module, class_info["name"])
+        model_class: BaseTorchModule = getattr(module, class_info["name"])
         return model_class.load(path)

     @classmethod
     def check_isinstance(cls, obj: Any) -> bool:  # noqa: ANN401
-        return isinstance(obj, BaseTorchModuleWithVocab)
+        return isinstance(obj, BaseTorchModule)


 class CatBoostDumper(BaseObjectDumper[CatBoostClassifier]):
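
`TorchModelDumper` now handles any `BaseTorchModule`, not only vocab-based modules. It records the class's import path in `class_info.json`, delegates the weight I/O to the model's own `dump`, and on load re-imports the class and calls its `load` classmethod. A simplified sketch of that dynamic-import round trip (helper names are hypothetical; the file name and lookup keys come from the diff):

```python
import importlib
import json
from pathlib import Path


def save_class_info(obj: object, path: Path) -> None:
    # Record where the class is defined so it can be re-imported at load time.
    class_info = {"module": obj.__class__.__module__, "name": obj.__class__.__name__}
    with (path / "class_info.json").open("w") as f:
        json.dump(class_info, f)


def resolve_class(path: Path) -> type:
    # Re-import the class named in class_info.json; the caller then invokes its load().
    with (path / "class_info.json").open("r") as f:
        class_info = json.load(f)
    module = importlib.import_module(class_info["module"])
    return getattr(module, class_info["name"])
```

Restoring a model then looks like `resolve_class(path).load(path)`, mirroring the dumper's `load` shown above.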

src/autointent/_wrappers/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -2,5 +2,6 @@
 from .embedder import Embedder
 from .vector_index import VectorIndex
 from .base_torch_module import BaseTorchModuleWithVocab
+from .base_torch_module import BaseTorchModule

-__all__ = ["BaseTorchModuleWithVocab", "Embedder", "Ranker", "VectorIndex"]
+__all__ = ["BaseTorchModule", "BaseTorchModuleWithVocab", "Embedder", "Ranker", "VectorIndex"]

src/autointent/_wrappers/base_torch_module.py

Lines changed: 51 additions & 42 deletions
@@ -13,10 +13,52 @@
 from autointent.configs import VocabConfig


-class BaseTorchModuleWithVocab(nn.Module, ABC):
+class BaseTorchModule(nn.Module, ABC):
+    @abstractmethod
+    def forward(self, text: torch.Tensor) -> torch.Tensor:
+        """Compute sentence embeddings for given text.
+
+        Args:
+            text: torch tensor of shape (B, T), token ids
+
+        Returns:
+            embeddings of shape (B, H)
+        """
+
+    @abstractmethod
+    def dump(self, path: Path) -> None:
+        """Dump torch module to disk.
+
+        This method encapsulates all the logic of dumping module's weights and
+        hyperparameters required for initialization from disk and nice inference.
+
+        Args:
+            path: path in file system
+        """
+
+    @classmethod
+    @abstractmethod
+    def load(cls, path: Path, device: str | None = None) -> Self:
+        """Load torch module from disk.
+
+        This method loads all weights and hyperparameters required for
+        initialization from disk and inference.
+
+        Args:
+            path: path in file system
+            device: torch notation for CPU, CUDA, MPS, etc. By default, it is inferred automatically.
+        """
+
+    @property
+    def device(self) -> torch.device:
+        """Torch device object where this module resides."""
+        return next(self.parameters()).device
+
+
+class BaseTorchModuleWithVocab(BaseTorchModule, ABC):
     def __init__(
         self,
-        embed_dim: int,
+        embed_dim: int | None = None,
         vocab_config: VocabConfig | None = None,
     ) -> None:
         super().__init__()
@@ -34,6 +76,9 @@ def __init__(

     def set_vocab(self, vocab: dict[str, Any]) -> None:
         """Save vocabulary into module's attributes and initialize embeddings matrix."""
+        if self.embed_dim is None:
+            msg = "embed_dim must be set to initialize embeddings"
+            raise ValueError(msg)
         self.vocab_config.vocab = vocab
         self.embedding = nn.Embedding(
             num_embeddings=len(self.vocab_config.vocab),
@@ -43,6 +88,10 @@ def set_vocab(self, vocab: dict[str, Any]) -> None:

     def build_vocab(self, utterances: list[str]) -> None:
         """Build vocabulary from training utterances."""
+        if self.embed_dim is None:
+            msg = "embed_dim must be set to initialize embeddings"
+            raise ValueError(msg)
+
         if self.vocab_config.vocab is not None:
             msg = "Vocab is already built."
             raise RuntimeError(msg)
@@ -80,43 +129,3 @@ def text_to_indices(self, utterances: list[str]) -> list[list[int]]:
             seq = seq + [self.vocab_config.padding_idx] * (self.vocab_config.max_seq_length - len(seq))
             sequences.append(seq)
         return sequences
-
-    @abstractmethod
-    def forward(self, text: torch.Tensor) -> torch.Tensor:
-        """Compute sentence embeddings for given text.
-
-        Args:
-            text: torch tensor of shape (B, T), token ids
-
-        Returns:
-            embeddings of shape (B, H)
-        """
-
-    @abstractmethod
-    def dump(self, path: Path) -> None:
-        """Dump torch module to disk.
-
-        This method encapsulates all the logic of dumping module's weights and
-        hyperparameters required for initialization from disk and nice inference.
-
-        Args:
-            path: path in file system
-        """
-
-    @classmethod
-    @abstractmethod
-    def load(cls, path: Path, device: str | None = None) -> Self:
-        """Load torch module from disk.
-
-        This method loads all weights and hyperparameters required for
-        initialization from disk and inference.
-
-        Args:
-            path: path in file system
-            device: torch notation for CPU, CUDA, MPS, etc. By default, it is inferred automatically.
-        """
-
-    @property
-    def device(self) -> torch.device:
-        """Torch device object where this module resides."""
-        return next(self.parameters()).device
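
With the abstract interface factored out into `BaseTorchModule`, models that do not need a vocabulary can still plug into the dump/load machinery. A minimal illustrative subclass, assuming only the interface shown above (`MeanPoolEncoder` and its on-disk layout are hypothetical, not part of the repository):

```python
import json
from pathlib import Path
from typing import Self  # Python 3.11+; use typing_extensions.Self otherwise

import torch
from torch import nn

from autointent._wrappers import BaseTorchModule


class MeanPoolEncoder(BaseTorchModule):
    """Hypothetical encoder: embeds token ids and mean-pools them into sentence vectors."""

    def __init__(self, vocab_size: int, embed_dim: int) -> None:
        super().__init__()
        self.vocab_size = vocab_size
        self.embed_dim = embed_dim
        self.embedding = nn.Embedding(vocab_size, embed_dim)

    def forward(self, text: torch.Tensor) -> torch.Tensor:
        # (B, T) token ids -> (B, H) sentence embeddings
        return self.embedding(text).mean(dim=1)

    def dump(self, path: Path) -> None:
        path.mkdir(parents=True, exist_ok=True)
        with (path / "hparams.json").open("w") as f:
            json.dump({"vocab_size": self.vocab_size, "embed_dim": self.embed_dim}, f)
        torch.save(self.state_dict(), path / "weights.pt")

    @classmethod
    def load(cls, path: Path, device: str | None = None) -> Self:
        with (path / "hparams.json").open("r") as f:
            hparams = json.load(f)
        model = cls(**hparams)
        model.load_state_dict(torch.load(path / "weights.pt", map_location=device or "cpu"))
        return model if device is None else model.to(device)
```

Because `TorchModelDumper` records the class's import path in `class_info.json` and defers to the model's own `dump`/`load`, a module like this would round-trip without any vocab-specific state.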
