This repository was archived by the owner on Oct 9, 2023. It is now read-only.
docs/source/general/finetuning.rst
Lines changed: 58 additions & 43 deletions
@@ -1,46 +1,18 @@
+.. _finetuning:
+
 **********
 Finetuning
 **********

-Finetuning (or transfer-learning) is the process of tweaking a model trained on a large dataset, to your particular (likely much smaller) dataset. All Flash tasks have a pre-trained backbone that was trained on large datasets such as ImageNet, and that allows to decrease training time significantly.
-
-The finetuning process can be split into 4 steps:
-
-1. Train a particular neural network model on a particular dataset. For computer vision, the [ImageNet dataset](http://www.image-net.org/search?q=cat) is widely used for pre-training model. As training is costly, libraries such as [torchvision](https://pytorch.org/docs/stable/torchvision/index.html) provide popular pre-trained model architectures. These are called backbones.
-
-2. Create a new neural network called the target model. Its architecture replicates the backbone (model from previous step) and parameters, except the latest layer which is usually replaced to fit the necessities of your data.
-
-3. This new layer (or layers) at the end of the backbone are used to match the backbone output to the number of target categories in your data. They are commonly referred to as the head'. The head is randomly initialized whereas the backbone conserves its pre-trained weights (for example the weights from ImageNet).
-
-4. Train the target model on a smaller target dataset. However, as the head (new layers) is untrained, the first results (gradients) will be random when training starts and could decrease the backbone performance (by changing its pre-trained parameters). Therefore, it is a good practice to "freeze" the backbone. This means the parameters of the backbone won't be updated until they are "unfrozen" a few epochs later.
-
-
-.. tip:: If you have a large dataset and prefer to train from scratch, see the training guide.
-
-You can finetune any Flash tasks on your own data in just a 3 simple steps:
-
-1. Load your data and organize it using `Flash DataModules`. Note that different tasks have different data modules (The :class:`~flash.vision.ImageClassificationData` for image classification, :class:`~flash.text.TextClassificationData` for text classification, etc.).
-
-2. Pick a model to run from a variety of Flash tasks: :class:`~flash.vision.ImageClassification`, :class:`~flash.text.TextClassifier`, :class:`~flash.tabular.TabularClassifier`, all optimized with the latest best practices.
-
-3. Finetune your model using :func:`~flash.Trainer.finetune` method. You will need to choose a finetune strategy.
+Finetuning (or transfer learning) is the process of adapting a model trained on a large dataset to your particular (likely much smaller) dataset. All Flash tasks have a backbone that was pre-trained on large datasets such as ImageNet. Starting from a pre-trained model decreases training time significantly.

+You can finetune any Flash task on your own data in just 3 simple steps:

-Finetune options
-================
+1. Load your data and organize it using Flash Data Modules. Note that different tasks have different data modules (:class:`~flash.vision.ImageClassificationData` for image classification, :class:`~flash.text.classification.data.TextClassificationData` for text classification, etc.).

-Flash provides a very simple interface for finetuning through `trainer.finetune` with its `strategy` parameters.
+2. Pick a model to run from a variety of Flash tasks: :class:`~flash.vision.ImageClassifier`, :class:`~flash.text.classification.model.TextClassifier`, :class:`~flash.tabular.TabularClassifier`, all optimized with the latest best practices.

-Flash finetune `strategy` argument can either a string or an instance of :class:`~python_lightning.callbacks.finetuning.BaseFinetuning`.
-
-Flash supports 4 builts-in Finetuning Callback accessible via those strings:
-
-* `no_freeze`: Don't freeze anything.
-* `freeze`: The parameters of the backbone won't be trainable after training starts.
-* `freeze_unfreeze`: The parameters of the backbone won't be trainable when training start and then those parameters will become trainable when training epoch reaches `unfreeze_epoch`.
-* `unfreeze_milestones`: The parameters of the backbone won't be trainable when training start. However, the latest layers of the backbone will become trainable when training epoch reaches the first milestone and the remaining layers when reaching the second one.
-
-For more options, you can pass in an instance of :class:`~python_lightning.callbacks.finetuning.BaseFinetuning` to the `strategy` parameter.
+3. Finetune your model using the :func:`~flash.core.trainer.Trainer.finetune` method. You will need to choose a finetune strategy.

 Once training is completed, you can use the model for inference to make predictions using the `predict` method.

@@ -65,32 +37,75 @@ Once training is completed, you can use the model for inference to make predicti
.. tip:: If you have a large dataset and prefer to train from scratch, see the :ref:`training` guide.
+
+
+Finetune strategies
+===================
+
+The Flash tasks contain pre-trained models trained on large datasets such as `ImageNet <http://www.image-net.org/>`_, which contains millions of images. These models are called **backbones** and serve as the starting point for finetuning.
+
+The model needs to be adapted to the new data available for the task. Usually, the last layers of the backbone need to be modified to match the backbone output to the number of target classes in the new data. These layers are commonly referred to as the **head**.
+For example, our backbone might be trained to classify 10 types of animals, but if our new dataset only contains images of bees and ants, we have to modify the final layer to fit just 2 classes.
+The head is randomly initialized, whereas the backbone keeps its pre-trained weights.
+
+The :func:`~flash.core.trainer.Trainer.finetune` method trains the modified model on the new dataset. As the head (new layers) is untrained, the first results (gradients) will be random when training starts and could degrade the backbone performance (by changing its pre-trained parameters). Therefore, it is good practice to "freeze" the backbone, meaning its parameters won't be updated until they are "unfrozen" a few epochs later.
+
+You can choose a finetuning strategy using the `strategy` parameter of :func:`~flash.core.trainer.Trainer.finetune`. The `strategy` argument can be either a string or an instance of :class:`~flash.core.finetuning.FlashBaseFinetuning`.
+
+Flash supports 2 built-in finetuning strategies that can be passed as strings:
+
+* `no_freeze`: Don't freeze anything; the backbone parameters can be modified during finetuning.
+* `freeze`: The parameters of the backbone won't be modified during finetuning.
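As a rough illustration of what the two string strategies mean, here is a minimal sketch in plain Python. It does not use Flash or PyTorch: `ToyModel` and `apply_strategy` are hypothetical names invented for this example, and the "trainable" flags merely stand in for what a real framework tracks per parameter.

```python
from dataclasses import dataclass, field


@dataclass
class ToyModel:
    """Hypothetical stand-in for a task: maps parameter names to a trainable flag."""

    backbone: dict = field(default_factory=lambda: {"conv1": True, "conv2": True})
    head: dict = field(default_factory=lambda: {"classifier": True})


def apply_strategy(model: ToyModel, strategy: str) -> ToyModel:
    """Mark backbone parameters as trainable or frozen; the head always trains."""
    if strategy not in ("no_freeze", "freeze"):
        raise ValueError(f"unknown strategy: {strategy!r}")
    backbone_trainable = strategy == "no_freeze"
    for name in model.backbone:
        model.backbone[name] = backbone_trainable
    for name in model.head:
        # The randomly initialized head must be trained under both strategies.
        model.head[name] = True
    return model


frozen = apply_strategy(ToyModel(), "freeze")
print(frozen.backbone, frozen.head)
```

With `freeze`, only the head would receive updates; with `no_freeze`, the untrained head and the pre-trained backbone are updated together from the first step.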
For more options, you can pass in an instance of :class:`~python_lightning.callbacks.finetuning.BaseFinetuning` to the `strategy` parameter.
+

 ==========================
 Custom callback finetuning
 ==========================

-You can pass in the built in callbacks for more customization:
+For more advanced finetuning, you can use Flash's built-in finetuning callbacks.

-.. code-block:: python
+* :class:`~flash.core.finetuning.FreezeUnfreeze`: The backbone parameters will be frozen for a given number of epochs (by default, `unfreeze_epoch` is set to 10).
* :class:`~flash.core.finetuning.UnfreezeMilestones`: This strategy defines 2 milestones: one milestone (epoch number) to unfreeze the last layers of the backbone, and a second milestone to unfreeze the remaining layers. For example, by default the first milestone is 5 and the second is 10. So for the first 4 epochs, the backbone parameters will be frozen. In epochs 5-9, only the last layers (5 by default) can be trained. From the 10th epoch on, all parameters in all layers can be trained.
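The schedules described in the two bullets above can be sketched as plain Python. This is only an illustration of the described behaviour, not Flash's implementation: the function names are hypothetical, epochs are assumed to be counted from 1, and a milestone is assumed to take effect at the start of the epoch with that number.

```python
def freeze_unfreeze_state(epoch: int, unfreeze_epoch: int = 10) -> str:
    """Backbone state under a FreezeUnfreeze-style schedule (hypothetical helper)."""
    if epoch < unfreeze_epoch:
        return "backbone frozen"
    return "all backbone layers trainable"


def unfreeze_milestones_state(
    epoch: int, milestones: tuple = (5, 10), num_layers: int = 5
) -> str:
    """Backbone state under an UnfreezeMilestones-style schedule (hypothetical helper)."""
    first, second = milestones
    if epoch < first:
        # Before the first milestone nothing in the backbone is updated.
        return "backbone frozen"
    if epoch < second:
        # Between the milestones only the last `num_layers` layers are updated.
        return f"last {num_layers} backbone layers trainable"
    return "all backbone layers trainable"


for epoch in (1, 4, 5, 9, 10, 12):
    print(epoch, unfreeze_milestones_state(epoch))
```

The defaults mirror the values quoted in the text (milestones 5 and 10, 5 layers, `unfreeze_epoch` of 10); whether the library counts epochs from 0 or 1 is not stated here, so treat the boundaries as approximate.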
+
+
+.. code-block:: python
+
+    # import UnfreezeMilestones
+    from flash.core.finetuning import UnfreezeMilestones
+
+    # finetune for 10 epochs. Backbone will be frozen for 3 epochs. The last 2 layers will be unfrozen for the first 4 epochs,
+    # and then the rest will be unfrozen on the 8th epoch
For even more customization, create your own finetuning callback.
+For even more customization, create your own finetuning callback. Learn more about callbacks `here <https://pytorch-lightning.readthedocs.io/en/stable/callbacks.html>`_.
docs/source/general/training.rst
Lines changed: 6 additions & 2 deletions
@@ -1,8 +1,12 @@
+
+.. _training:
+
 *********************
 Training from scratch
 *********************

-Some Flash tasks have been pretrained on large data sets, to accelerate your training (calling the :func:`~flash.Trainer.finetune` method using a pretrained backbone will fine-tune the backbone to generate a model customized to your data set and desired task). If you want to train the task from scratch instead, pass `pretrained=False` parameter when creating your task. Then, use the :func:`~flash.Trainer.fit` method to train your model.
+Some Flash tasks have been pretrained on large data sets. To accelerate your training, call the :func:`~flash.core.trainer.Trainer.finetune` method with a pretrained backbone; it fine-tunes the backbone to generate a model customized to your data set and desired task. If you want to train the task from scratch instead, pass the `pretrained=False` parameter when creating your task. Then, use the :func:`~flash.core.trainer.Trainer.fit` method to train your model.
+

 .. code-block:: python

@@ -71,7 +75,7 @@ Flash tasks supports many advanced training functionalities out-of-the-box, such
     # Train on TPUs
     trainer.fit(tpu_cores=8)

-You can add to the flash Trainer any argument from the Lightning trainer! Learn more about the Lightning Trainer `here <https://pytorch-lightning.readthedocs.io/en/stable/trainer.html>`.
+You can add to the flash Trainer any argument from the Lightning trainer! Learn more about the Lightning Trainer `here <https://pytorch-lightning.readthedocs.io/en/stable/trainer.html>`_.