
Commit 8a62247

Model training docs update
1 parent 8b40c48 commit 8a62247

1 file changed (+94, -14)

docs/docs/dev/model_training.md


## Training invocation

In order to cater to domain experts, model training does not have to be explicitly invoked by the user. Instead, Label Sleuth automatically invokes model training in the background when certain conditions are met. To ensure that the user can see the most up-to-date model predictions and receive appropriate active learning guidance, Label Sleuth trains new models (which can be thought of as new versions of the classifier) as the user continues labeling.

### Training invocation criteria

Label Sleuth starts a new model training iteration whenever the following conditions are met (a sketch of how these conditions combine in binary mode is shown after the tables below):

::::{tab-set}
:::{tab-item} Binary mode

| Condition on | Description | Default |
|---|---|---|
| **Number of positive labels** | The user has to provide a minimum number of positive labels. The threshold of required positive labels can be configured by setting the value of the `binary_flow.first_model_positive_threshold` parameter in the system's [configuration file](configuration.md). | 20 |
| **Number of negative labels** | The user has to provide a minimum number of negative labels. The threshold of required negative labels can be configured by setting the value of the `binary_flow.first_model_negative_threshold` parameter in the system's [configuration file](configuration.md). | 0 |
| **Number of label changes** | The user has to change a minimum number of labels since the last model training iteration (unless it is the first iteration). A change can be assigning a label (positive or negative) to an element, or changing an existing label. The threshold of required label changes can be configured by setting the value of the `binary_flow.changed_element_threshold` parameter in the system's [configuration file](configuration.md). | 20 |
:::
:::{tab-item} Multiclass mode

| Condition on | Description | Default |
|---|---|---|
| **Number of labels per category** | The user has to provide a minimum number of labels per category. The threshold of required labels per category can be configured by setting the value of the `multiclass_flow.per_class_labeling_threshold` parameter in the system's [configuration file](configuration.md). | 5 |
| **Number of label changes** | The user has to change a minimum number of labels since the last model training iteration (unless it is the first iteration). A change can be assigning a label to an element, or changing an existing label. The threshold of required label changes can be configured by setting the value of the `multiclass_flow.changed_element_threshold` parameter in the system's [configuration file](configuration.md). | 20 |
| **Zero-shot first iteration** | If the `multiclass_flow.zero_shot_first_model` parameter in the system's [configuration file](configuration.md) is True, a zero-shot model will be used to make predictions on the entire dataset after the categories are created. | False |
| **Category list changes** | If new categories are created, or existing categories are edited or deleted, and a model is already available, a new model training iteration will be triggered. | - |
:::
::::
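
As an illustration of how these criteria combine in binary mode, here is a minimal sketch. It is not Label Sleuth's actual implementation; the function and variable names are hypothetical, and the threshold defaults mirror the values listed in the table above.

```python
# Hypothetical sketch of the binary-mode training invocation check.
# The names below are illustrative; they are not Label Sleuth's internal API.

def should_start_training(positive_count: int,
                          negative_count: int,
                          changes_since_last_iteration: int,
                          is_first_iteration: bool,
                          first_model_positive_threshold: int = 20,
                          first_model_negative_threshold: int = 0,
                          changed_element_threshold: int = 20) -> bool:
    # Enough positive and negative labels must have been provided.
    if positive_count < first_model_positive_threshold:
        return False
    if negative_count < first_model_negative_threshold:
        return False
    # After the first iteration, a minimum number of labels must have been
    # added or changed since the previous training run.
    if not is_first_iteration and changes_since_last_iteration < changed_element_threshold:
        return False
    return True
```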

## Training set selection

Label Sleuth currently supports the following training set selection strategies:

::::{tab-set}
:::{tab-item} Binary mode

The employed training set selection strategy can be configured by setting the value of the `binary_flow.training_set_selection_strategy` parameter in the system's [configuration file](configuration.md). Note that in each case, Label Sleuth removes duplicates so that only unique elements are included in the training set.

| Training set selection strategy | Description |
|---|---|
| `ALL_LABELED` | Use all the examples labeled by the user (without any modification). |
| `ALL_LABELED_NO_VERIFICATION` | Use all the examples labeled by the user (without any modification), without verifying that both positive and negative labels are present. |
| `ALL_LABELED_PLUS_UNLABELED_AS_NEGATIVE_EQUAL_RATIO` | Ensure a ratio of _1 negative example for every positive example_. See below for details of how this ratio is ensured. |
| `ALL_LABELED_PLUS_UNLABELED_AS_NEGATIVE_X2_RATIO` <br /><defvalue>default</defvalue> | Ensure a ratio of _2 negative examples for every positive example_. See below for details of how this ratio is ensured. |
| `ALL_LABELED_PLUS_UNLABELED_AS_NEGATIVE_X10_RATIO` | Ensure a ratio of _10 negative examples for every positive example_. See below for details of how this ratio is ensured. |
| `ALL_LABELED_INCLUDE_WEAK` | Use all the examples labeled by the user and include weak labels. |
| `ALL_LABELED_INCLUDE_WEAK_PLUS_UNLABELED_AS_NEGATIVE_EQUAL_RATIO` | Ensure a ratio of _1 negative example for every positive example_ and include weak labels. See below for details of how this ratio is ensured and how weak labels are included. |
| `ALL_LABELED_INCLUDE_WEAK_PLUS_UNLABELED_AS_NEGATIVE_X2_RATIO` | Ensure a ratio of _2 negative examples for every positive example_ and include weak labels. See below for details of how this ratio is ensured and how weak labels are included. |
| `ALL_LABELED_INCLUDE_WEAK_PLUS_UNLABELED_AS_NEGATIVE_X10_RATIO` | Ensure a ratio of _10 negative examples for every positive example_ and include weak labels. See below for details of how this ratio is ensured and how weak labels are included. |

If one of the training set selection strategies specifying a ratio of negative to positive examples is chosen, Label Sleuth ensures the respective ratio as follows: if the user has labeled fewer negative examples than the ratio requires, some _unlabeled_ examples are automatically added to the training set as negative examples. On the other hand, if the number of negative examples labeled by the user exceeds the ratio, only a sample of the user-labeled negative examples is included in the training set. A sketch of this logic is shown after the tables below.

If one of the training set selection strategies specifying that weak labels should be included is chosen, Label Sleuth adds all available weak labels to the training data. Currently, the only way to include weak labels is to upload them to the workspace with `label_type` set to `Weak`.

:::
:::{tab-item} Multiclass mode

The employed training set selection strategy can be configured by setting the value of the `multiclass_flow.training_set_selection_strategy` parameter in the system's [configuration file](configuration.md). Note that in each case, Label Sleuth removes duplicates so that only unique elements are included in the training set.

| Training set selection strategy | Description |
|---|---|
| `ALL_LABELED_MULTICLASS` | Use all the examples labeled by the user. |

:::
::::
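
The following is a minimal sketch of the ratio-enforcement step described above for binary mode. It is illustrative only; the function and variable names are hypothetical and do not correspond to Label Sleuth's internal API.

```python
import random

# Hypothetical sketch of how a negative:positive ratio could be enforced
# when building the training set. Names are illustrative only.

def build_training_set(positives, negatives, unlabeled, negative_ratio=2, seed=0):
    """Return a training set with roughly `negative_ratio` negatives per positive."""
    rng = random.Random(seed)
    target_negatives = negative_ratio * len(positives)
    if len(negatives) < target_negatives:
        # Too few user-labeled negatives: add unlabeled elements as negatives.
        missing = target_negatives - len(negatives)
        extra = rng.sample(unlabeled, min(missing, len(unlabeled)))
        selected_negatives = negatives + extra
    else:
        # Too many user-labeled negatives: keep only a sample of them.
        selected_negatives = rng.sample(negatives, target_negatives)
    return positives + selected_negatives
```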

## Model selection

Label Sleuth currently includes implementations of the following machine learning models:

::::{tab-set}
:::{tab-item} Binary mode

| Model name | Description | Implementation details | Hardware requirements |
|---|---|---|---|
| `NB_OVER_BOW` | Naive Bayes over Bag-of-words | [scikit-learn](https://scikit-learn.org) implementation | - |
| `NB_OVER_WORD_EMBEDDINGS` | Naive Bayes over [word embeddings*](word_embeddings) | - | - |
| `SVM_OVER_BOW` | Support Vector Machine over Bag-of-words | [scikit-learn](https://scikit-learn.org) implementation | - |
| `SVM_OVER_WORD_EMBEDDINGS` | Support Vector Machine over [word embeddings*](word_embeddings) | - | - |
| `SVM_OVER_SBERT` | Support Vector Machine over Sentence-BERT embeddings | - | - |
| `SVM_ENSEMBLE` | Ensemble of `SVM_OVER_BOW` and `SVM_OVER_WORD_EMBEDDINGS` | - | - |
| `HF_BERT` | BERT ([Devlin et al. 2018](https://arxiv.org/abs/1810.04805)) | Pytorch implementation using the [Hugging Face Transformers](https://github.com/huggingface/transformers) library | GPU _(recommended)_ |
| `HF_XLM_ROBERTA` | XLM-R ([Conneau et al., 2019](https://arxiv.org/abs/1911.02116)) | Pytorch implementation using the [Hugging Face Transformers](https://github.com/huggingface/transformers) library | GPU _(recommended)_ |
| `BINARY_FLANT5XL_PT` | FLAN-T5-XL ([Chung et al., 2022](https://arxiv.org/pdf/2210.11416.pdf)) | Requires a Project ID and an API key from [WatsonX](https://www.ibm.com/watsonx). | - |

:::
:::{tab-item} Multiclass mode

| Model name | Description | Implementation details | Hardware requirements |
|---|---|---|---|
| `MULTICLASS_SVM_BOW` | Support Vector Machine over Bag-of-words | [scikit-learn](https://scikit-learn.org) implementation | - |
| `MULTICLASS_SVM_WORD_EMBEDDINGS` | Support Vector Machine over [word embeddings*](word_embeddings) | - | - |
| `MULTICLASS_SVM_ENSEMBLE` | Ensemble of `MULTICLASS_SVM_BOW` and `MULTICLASS_SVM_WORD_EMBEDDINGS` | - | - |
| `MULTICLASS_FLANT5XL_PT` | FLAN-T5-XL ([Chung et al., 2022](https://arxiv.org/pdf/2210.11416.pdf)) | Requires a Project ID and an API key from [WatsonX](https://www.ibm.com/watsonx). | - |
:::
::::

Within the codebase, the list of supported models can be found in Label Sleuth's [model catalog](https://github.com/label-sleuth/label-sleuth/blob/main/label_sleuth/models/core/catalog.py). Note that some models may have special hardware requirements to perform as expected (e.g., they require the presence of a GPU).

### Model policies

The model architecture that is trained in each iteration is prescribed by the employed _model policy_. In its most basic form, a model policy is _static_, resulting in the system always using the same model for every iteration. However, model policies can also be _dynamic_, allowing the system to switch between different types of models depending on the iteration. For instance, one can create a model policy instructing Label Sleuth to use a lightweight, fast-to-train model (such as SVM) for the first few iterations and then switch to a more complex, slower-to-train model (such as BERT) for later iterations; a sketch of such a policy is shown after the tables below. Label Sleuth currently supports the following model policies:

::::{tab-set}
:::{tab-item} Binary mode

| Model policy | Model type | Description | Supported languages |
|---|---|---|---|
| `STATIC_SVM_BOW` | Static | Use the `SVM_OVER_BOW` model in every iteration | [All languages](languages.md) |
| `STATIC_SVM_WORD_EMBEDDINGS` | Static | Use the `SVM_OVER_WORD_EMBEDDINGS` model in every iteration | [All languages](languages.md) |
| `STATIC_SVM_ENSEMBLE` <br /><defvalue>default</defvalue> | Static | Use the `SVM_ENSEMBLE` model in every iteration | [All languages](languages.md) |
| `STATIC_SVM_SBERT` | Static | Use the `SVM_OVER_SBERT` model in every iteration | English |
| `STATIC_HF_BERT` | Static | Use the `HF_BERT` model in every iteration | English |
| `STATIC_HF_XLM_ROBERTA` | Static | Use the `HF_XLM_ROBERTA` model in every iteration | [All languages](languages.md) |
| `STATIC_BINARY_FLANT5XL_PT` | Static | Use the `BINARY_FLANT5XL_PT` model in every iteration | English |

:::
:::{tab-item} Multiclass mode

| Model policy | Model type | Description | Supported languages |
|---|---|---|---|
| `STATIC_MULTICLASS_SVM_BOW` | Static | Use the `MULTICLASS_SVM_BOW` model in every iteration | [All languages](languages.md) |
| `STATIC_MULTICLASS_SVM_WORD_EMBEDDINGS` | Static | Use the `MULTICLASS_SVM_WORD_EMBEDDINGS` model in every iteration | [All languages](languages.md) |
| `STATIC_MULTICLASS_SVM_ENSEMBLE` | Static | Use the `MULTICLASS_SVM_ENSEMBLE` model in every iteration | [All languages](languages.md) |
| `STATIC_MULTICLASS_FLANT5XL_PT` | Static | Use the `MULTICLASS_FLANT5XL_PT` model in every iteration | English |

:::
::::
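
As an illustration of what a _dynamic_ policy could look like, here is a minimal sketch that switches from a fast SVM-based model to a BERT-based model after a few iterations. The class and attribute names are hypothetical and do not correspond to Label Sleuth's internal policy API; only the model names (`SVM_ENSEMBLE`, `HF_BERT`) come from the tables above.

```python
# Hypothetical sketch of a dynamic model policy; names are illustrative only.
from dataclasses import dataclass


@dataclass
class DynamicSvmThenBertPolicy:
    """Use a fast SVM ensemble for the first few iterations, then switch to BERT."""
    switch_iteration: int = 3

    def get_model_type(self, iteration_num: int) -> str:
        # Early iterations: a lightweight model that retrains quickly.
        if iteration_num < self.switch_iteration:
            return "SVM_ENSEMBLE"
        # Later iterations: a more accurate but slower-to-train model.
        return "HF_BERT"


# Example: which model would be trained in each of the first five iterations?
policy = DynamicSvmThenBertPolicy()
print([policy.get_model_type(i) for i in range(5)])
# ['SVM_ENSEMBLE', 'SVM_ENSEMBLE', 'SVM_ENSEMBLE', 'HF_BERT', 'HF_BERT']
```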

Within the codebase, the list of available model policies can be found [here](https://github.com/label-sleuth/label-sleuth/blob/main/label_sleuth/models/core/model_policies.py). The model policy can be configured by setting the `model_policy` parameter in the system's [configuration file](configuration.md).

1. Implement a new `ModelAPI`.

Machine learning models are integrated by adding a new implementation of the `ModelAPI` class.
The main functions are *_train()*, *load_model()*, *get_supported_languages()* and *infer()*:

**Train** a new model.
<br />

**Load** a trained model.

```python
def load_model(self, model_path: str):
```

- model_path: path to a directory containing all model components

Returns an object that contains all the components that are necessary to perform inference (e.g., the trained model itself, the language recognized by the model, a trained vectorizer/tokenizer etc.).

**Get** the supported languages _(added in version 0.9.1)_.

```python
def get_supported_languages(self) -> Set[Language]:
```

Returns the set of languages supported by the model.

**Infer** a given sequence of elements and return the results.

```python
def infer(self, model_components, items_to_infer) -> Sequence[Prediction]:
```

- model_components: the return value of `load_model()`, i.e., an object containing all the components that are necessary to perform inference
- items_to_infer: a list of dictionaries with at least the "text" field. Additional fields can be passed, e.g. *[{'text': 'text1', 'additional_field': 'value1'}, {'text': 'text2', 'additional_field': 'value2'}]*

Returns a list of [Prediction](https://github.com/label-sleuth/label-sleuth/blob/1424a9ab01697e12396bc33fd608158d61d55e24/label_sleuth/models/core/prediction.py#L20) objects - one for each item in *items_to_infer* - where Prediction.label is a boolean and Prediction.score is a float in the range [0-1]. Additional outputs can be passed by inheriting from the base Prediction class and overriding the get_predictions_class() method.
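
To make this concrete, below is a minimal, hypothetical sketch of what an `infer()` implementation might look like for a scikit-learn-based model. It is not taken from the Label Sleuth codebase: the stand-in `Prediction` class and the components stored on `model_components` are assumptions, and a real extension would subclass `ModelAPI` and also provide `_train()`, `load_model()` and `get_supported_languages()`.

```python
from dataclasses import dataclass
from typing import Sequence


# Stand-in for the Prediction class linked above (label is a boolean,
# score is a float in [0, 1]); a real implementation would import it
# from the Label Sleuth codebase instead of redefining it.
@dataclass
class Prediction:
    label: bool
    score: float


class MySklearnModel:  # a real implementation would subclass ModelAPI
    def infer(self, model_components, items_to_infer) -> Sequence[Prediction]:
        # model_components is assumed to be the object returned by load_model(),
        # here holding a fitted scikit-learn vectorizer and classifier.
        texts = [item["text"] for item in items_to_infer]
        features = model_components.vectorizer.transform(texts)
        # Probability of the positive class for each element.
        scores = model_components.classifier.predict_proba(features)[:, 1]
        return [Prediction(label=score >= 0.5, score=float(score))
                for score in scores]
```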
