Skip to content
This repository was archived by the owner on Oct 9, 2023. It is now read-only.

Commit 923f455

Browse files
edenlightningwilliamFalconcarmoccateddykokerBorda
authored
Another pass at docs (#24)
* Update README.md * Update README.md * Update README.md * Update README.md * Update README.md * Update README.md * docs (#16) * docs (#17) * Update README.md * Update README.md * Update README.md * fix docs * autosummary * Update custom_task.rst * sync * docs * docs * docs * Apply suggestions from code review Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com> Co-authored-by: Teddy Koker <teddy.koker@gmail.com> * Update docs/source/quickstart.rst Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Update tabular_classification.rst * Update tabular_classification.rst * finetuning Co-authored-by: William Falcon <waf2107@columbia.edu> Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com> Co-authored-by: Teddy Koker <teddy.koker@gmail.com> Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> Co-authored-by: chaton <thomas@grid.ai>
1 parent edb65c7 commit 923f455

10 files changed

Lines changed: 209 additions & 257 deletions

docs/source/custom_task.rst

Lines changed: 1 addition & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -2,36 +2,5 @@
22
Write a Task
33
############
44

5-
You can create your own Flash task to if you want to solve any other deep learning problem.
65

7-
See this example for defining a linear classifier task:
8-
9-
.. testcode::
10-
11-
import torch
12-
import torch.nn.functional as F
13-
from pytorch_lightning.metrics import Accuracy
14-
from flash.core.classification import ClassificationTask
15-
16-
class LinearClassifier(ClassificationTask):
17-
def __init__(
18-
self,
19-
num_inputs: int,
20-
num_classes: int,
21-
loss_fn=F.cross_entropy,
22-
optimizer=torch.optim.SGD,
23-
metrics=[Accuracy()],
24-
learning_rate=1e-3,
25-
):
26-
super().__init__(model=None,
27-
loss_fn=loss_fn,
28-
optimizer=optimizer,
29-
metrics=metrics,
30-
learning_rate=learning_rate,
31-
)
32-
self.save_hyperparameters()
33-
34-
self.linear = torch.nn.Linear(num_inputs, num_classes)
35-
36-
def forward(self, x):
37-
return self.linear(x)
6+
**Detailed guide comming soon!**

docs/source/general/finetuning.rst

Lines changed: 58 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -1,46 +1,18 @@
1+
.. _finetuning:
2+
13
**********
24
Finetuning
35
**********
46

5-
Finetuning (or transfer-learning) is the process of tweaking a model trained on a large dataset, to your particular (likely much smaller) dataset. All Flash tasks have a pre-trained backbone that was trained on large datasets such as ImageNet, and that allows to decrease training time significantly.
6-
7-
The finetuning process can be split into 4 steps:
8-
9-
1. Train a particular neural network model on a particular dataset. For computer vision, the [ImageNet dataset](http://www.image-net.org/search?q=cat) is widely used for pre-training model. As training is costly, libraries such as [torchvision](https://pytorch.org/docs/stable/torchvision/index.html) provide popular pre-trained model architectures. These are called backbones.
10-
11-
2. Create a new neural network called the target model. Its architecture replicates the backbone (model from previous step) and parameters, except the latest layer which is usually replaced to fit the necessities of your data.
12-
13-
3. This new layer (or layers) at the end of the backbone are used to match the backbone output to the number of target categories in your data. They are commonly referred to as the head'. The head is randomly initialized whereas the backbone conserves its pre-trained weights (for example the weights from ImageNet).
14-
15-
4. Train the target model on a smaller target dataset. However, as the head (new layers) is untrained, the first results (gradients) will be random when training starts and could decrease the backbone performance (by changing its pre-trained parameters). Therefore, it is a good practice to "freeze" the backbone. This means the parameters of the backbone won't be updated until they are "unfrozen" a few epochs later.
16-
17-
18-
.. tip:: If you have a large dataset and prefer to train from scratch, see the training guide.
19-
20-
You can finetune any Flash tasks on your own data in just a 3 simple steps:
21-
22-
1. Load your data and organize it using `Flash DataModules`. Note that different tasks have different data modules (The :class:`~flash.vision.ImageClassificationData` for image classification, :class:`~flash.text.TextClassificationData` for text classification, etc.).
23-
24-
2. Pick a model to run from a variety of Flash tasks: :class:`~flash.vision.ImageClassification`, :class:`~flash.text.TextClassifier`, :class:`~flash.tabular.TabularClassifier`, all optimized with the latest best practices.
25-
26-
3. Finetune your model using :func:`~flash.Trainer.finetune` method. You will need to choose a finetune strategy.
7+
Finetuning (or transfer-learning) is the process of tweaking a model trained on a large dataset, to your particular (likely much smaller) dataset. All Flash tasks have a pre-trained backbone that was already trained on large datasets such as ImageNet. Finetuning on already pretrained models decrease training time significantly.
278

9+
You can finetune any Flash task on your own data in just a 3 simple steps:
2810

29-
Finetune options
30-
================
11+
1. Load your data and organize it using Flash Data Modules. Note that different tasks have different data modules (The :class:`~flash.vision.ImageClassificationData` for image classification, :class:`~flash.text.classification.data.TextClassificationData` for text classification, etc.).
3112

32-
Flash provides a very simple interface for finetuning through `trainer.finetune` with its `strategy` parameters.
13+
2. Pick a model to run from a variety of Flash tasks: :class:`~flash.vision.ImageClassifier`, :class:`~flash.text.classification.model.TextClassifier`, :class:`~flash.tabular.TabularClassifier`, all optimized with the latest best practices.
3314

34-
Flash finetune `strategy` argument can either a string or an instance of :class:`~python_lightning.callbacks.finetuning.BaseFinetuning`.
35-
36-
Flash supports 4 builts-in Finetuning Callback accessible via those strings:
37-
38-
* `no_freeze`: Don't freeze anything.
39-
* `freeze`: The parameters of the backbone won't be trainable after training starts.
40-
* `freeze_unfreeze`: The parameters of the backbone won't be trainable when training start and then those parameters will become trainable when training epoch reaches `unfreeze_epoch`.
41-
* `unfreeze_milestones`: The parameters of the backbone won't be trainable when training start. However, the latest layers of the backbone will become trainable when training epoch reaches the first milestone and the remaining layers when reaching the second one.
42-
43-
For more options, you can pass in an instance of :class:`~python_lightning.callbacks.finetuning.BaseFinetuning` to the `strategy` parameter.
15+
3. Finetune your model using :func:`~flash.core.trainer.Trainer.finetune` method. You will need to choose a finetune strategy.
4416

4517
Once training is completed, you can use the model for inference to make predictions using the `predict` method.
4618

@@ -65,32 +37,75 @@ Once training is completed, you can use the model for inference to make predicti
6537
trainer = flash.Trainer()
6638
trainer.finetune(model, data, strategy="no_freeze")
6739
40+
.. tip:: If you have a large dataset and prefer to train from scratch, see the :ref:`training` guide.
41+
42+
43+
Finetune strategies
44+
===================
45+
46+
The flash tasks contain pre-trained models trained on large datasets such as `ImageNet <http://www.image-net.org/>`_, which contains millions of images. These models are called **backbones**. This will be used as the starting point for finetuning.
47+
48+
The model needs to be adapted or refined for the new data available for the task. Usually, the last layers of the backbone need to be modified, to match the backbone output to the number of target classes of the new data. These layers are commonly referred to as the **head**.
49+
For example, our backbone might be trained to classify 10 types of animals, but maybe our new dataset only contains images of bees and ants, so we would have to modify our final layer to fit just 2 classes.
50+
The head is randomly initialized whereas the backbone conserves its pre-trained weights.
51+
52+
The :func:`~flash.core.trainer.Trainer.finetune` method trains the new modified model using the new dataset. As the head (new layers) is untrained, the first results (gradients) will be random when training starts and could decrease the backbone performance (by changing its pre-trained parameters). Therefore, it is a good practice to "freeze" the backbone, meaning the parameters of the backbone won't be updated until they are "unfrozen" a few epochs later.
53+
54+
You can choose a finetuning strategy using :func:`~flash.core.trainer.Trainer.finetune` `strategy` parameter. Flash finetune `strategy` argument can either a string or an instance of :class:`~flash.core.finetuning.FlashBaseFinetuning`.
55+
56+
Flash supports 2 builts-in Finetuning strategies, that can be passed as strings:
57+
58+
* `no_freeze`: Don't freeze anything, the backbone parameters can be modified during finetuning.
59+
* `freeze`: The parameters of the backbone won't be modified during finetuning.
60+
61+
.. code-block:: python
62+
63+
# using the freeze strategy
64+
trainer.finetune(model, data, strategy="freeze")
65+
66+
# using the no_freeze strategy
67+
trainer.finetune(model, data, strategy="no_freeze")
68+
69+
70+
For more options, you can pass in an instance of :class:`~python_lightning.callbacks.finetuning.BaseFinetuning` to the `strategy` parameter.
71+
6872

6973
==========================
7074
Custom callback finetuning
7175
==========================
7276

73-
You can pass in the built in callbacks for more customization:
77+
For more advanced finetuning, you can use flash built-in finetuning callbacks.
7478

75-
.. code-block:: python
79+
* :class:`~flash.core.finetuning.FreezeUnfreeze`: The backbone parameters will be frozen for a given number of epochs (by default the `unfreeze_epoch` is set to 10).
7680

77-
# finetune for 10 epochs
78-
trainer = flash.Trainer()
79-
trainer.finetune(model, data, strategy="freeze_unfreeze")
8081

81-
# or import FreezeUnfreeze
82+
.. code-block:: python
83+
84+
# import FreezeUnfreeze
8285
from flash.core.finetuning import FreezeUnfreeze
8386
8487
# finetune for 10 epochs. Backbone will be frozen for 5 epochs.
85-
trainer = flash.Trainer()
88+
trainer = flash.Trainer(max_epochs=10)
8689
trainer.finetune(model, data, strategy=FreezeUnfreeze(unfreeze_epoch=5))
8790
91+
* :class:`~flash.core.finetuning.UnfreezeMilestones`: This strategy define 2 milestones, one milestone (epoch number) to unfreeze the last layers of the backbone, and a second milestone to unfreeze the remaining layers. For example, by default the first milestone is 5 and the second is 10. So for the first 4 epochs, the backbone parameters will be frozen. In epochs 5-9, only the last layers (5 by deafult) can be trained. After the 10thg epoch, all parameters in all layers can be trained.
92+
93+
94+
.. code-block:: python
95+
96+
# import UnfreezeMilestones
97+
from flash.core.finetuning import UnfreezeMilestones
98+
99+
# finetune for 10 epochs. Backbone will be frozen for 3 epochs. The last 2 layers will be unfrozen for the first 4 epochs,
100+
# and then the rest will be unfrozen on the 8th epoch
101+
trainer = flash.Trainer(max_epochs=10)
102+
trainer.finetune(model, data, strategy=UnfreezeMilestones(unfreeze_milestones=(5,8), num_layers=2))
88103
89104
90105
Custom callback finetuning
91106
==========================
92107

93-
For even more customization, create your own finetuning callback.
108+
For even more customization, create your own finetuning callback. Learn more about callbacks `here <https://pytorch-lightning.readthedocs.io/en/stable/callbacks.html>`_.
94109

95110
.. code-block:: python
96111

docs/source/general/predictions.rst

Lines changed: 22 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -15,19 +15,20 @@ You can pass in a sample of data (image file path, a string of text, etc) to the
1515

1616
.. code-block:: python
1717
18-
from flash import Trainer
19-
from flash import download_data
20-
from flash.vision import ImageClassificationData, ImageClassifier
18+
from flash import Trainer
19+
from flash.core.data import download_data
20+
from flash.vision import ImageClassificationData, ImageClassifier
2121
22-
# 1. Download the data set
23-
download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip", 'data/')
2422
25-
# 2. Load the model from a checkpoint
26-
model = ImageClassifier.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/image_classification_model.pt")
23+
# 1. Download the data set
24+
download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip", 'data/')
2725
28-
# 3. Predict whether the image contains an ant or a bee
29-
predictions = model.predict("data/hymenoptera_data/val/bees/65038344_52a45d090d.jpg")
30-
print(predictions)
26+
# 2. Load the model from a checkpoint
27+
model = ImageClassifier.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/image_classification_model.pt")
28+
29+
# 3. Predict whether the image contains an ant or a bee
30+
predictions = model.predict("data/hymenoptera_data/val/bees/65038344_52a45d090d.jpg")
31+
print(predictions)
3132
3233
3334
@@ -36,45 +37,19 @@ Predict on a csv file
3637

3738
.. code-block:: python
3839
39-
from flash import download_data
40-
from flash.tabular import TabularClassifier
41-
42-
# 1. Download the data
43-
download_data("https://pl-flash-data.s3.amazonaws.com/titanic.zip", 'data/')
44-
45-
# 2. Load the model from a checkpoint
46-
model = TabularClassifier.load_from_checkpoint(
47-
"https://flash-weights.s3.amazonaws.com/tabular_classification_model.pt"
48-
)
49-
50-
# 3. Generate predictions from a csv file! Who would survive?
51-
predictions = model.predict("data/titanic/titanic.csv")
52-
print(predictions)
53-
54-
55-
Distributed predictions
56-
=======================
57-
58-
For more advanced options like distributed inference, you need to use the :func:`~flash.core.trainer.Trainer.predict` method.
59-
60-
.. code-block:: python
61-
62-
from flash import Trainer
63-
from flash import download_data
64-
from flash.vision import ImageClassificationData, ImageClassifier
65-
66-
67-
# 1. Download the data set
68-
download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip", 'data/')
69-
70-
# 2. Load the model from a checkpoint
71-
model = ImageClassifier.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/image_classification_model.pt")
40+
from flash.core.data import download_data
41+
from flash.tabular import TabularClassifier
7242
43+
# 1. Download the data
44+
download_data("https://pl-flash-data.s3.amazonaws.com/titanic.zip", 'data/')
7345
74-
# 3b. Generate predictions with a whole folder!
75-
datamodule = ImageClassificationData.from_folder(folder="data/hymenoptera_data/predict/")
46+
# 2. Load the model from a checkpoint
47+
model = TabularClassifier.load_from_checkpoint(
48+
"https://flash-weights.s3.amazonaws.com/tabular_classification_model.pt"
49+
)
7650
77-
predictions = Trainer().predict(model, datamodule=datamodule)
78-
print(predictions)
51+
# 3. Generate predictions from a csv file! Who would survive?
52+
predictions = model.predict("data/titanic/titanic.csv")
53+
print(predictions)
7954
8055

docs/source/general/training.rst

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,12 @@
1+
2+
.. _training:
3+
14
*********************
25
Training from scratch
36
*********************
47

5-
Some Flash tasks have been pretrained on large data sets, to accelerate your training (calling the :func:`~flash.Trainer.finetune` method using a pretrained backbone will fine-tune the backbone to generate a model customized to your data set and desired task). If you want to train the task from scratch instead, pass `pretrained=False` parameter when creating your task. Then, use the :func:`~flash.Trainer.fit` method to train your model.
8+
Some Flash tasks have been pretrained on large data sets. To accelerate your training, calling the :func:`~flash.core.trainer.Trainer.finetune` method using a pretrained backbone will fine-tune the backbone to generate a model customized to your data set and desired task. If you want to train the task from scratch instead, pass `pretrained=False` parameter when creating your task. Then, use the :func:`~flash.core.trainer.Trainer.fit` method to train your model.
9+
610

711
.. code-block:: python
812
@@ -71,7 +75,7 @@ Flash tasks supports many advanced training functionalities out-of-the-box, such
7175
# Train on TPUs
7276
trainer.fit(tpu_cores=8)
7377
74-
You can add to the flash Trainer any argument from the Lightning trainer! Learn more about the Lightning Trainer `here <https://pytorch-lightning.readthedocs.io/en/stable/trainer.html>`.
78+
You can add to the flash Trainer any argument from the Lightning trainer! Learn more about the Lightning Trainer `here <https://pytorch-lightning.readthedocs.io/en/stable/trainer.html>`_.
7579

7680

7781
Trainer API

0 commit comments

Comments
 (0)