Skip to content
This repository was archived by the owner on Oct 9, 2023. It is now read-only.

Commit ff3f1ce

Browse files
teddykokercarmoccatchaton
authored
Custom Task Tutorial (#42)
* started nb * just need to add words * update * more words * update rst with notebook * Update docs/source/custom_task.rst Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com> * Update docs/source/custom_task.rst Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com> * Update docs/source/custom_task.rst Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com> * link to datapipeline page * resolve doc Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com> Co-authored-by: chaton <thomas@grid.ai>
1 parent 923f455 commit ff3f1ce

3 files changed

Lines changed: 411 additions & 4 deletions

File tree

docs/source/custom_task.rst

Lines changed: 138 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,140 @@
1-
############
2-
Write a Task
3-
############
1+
Tutorial: Creating a Custom Task
2+
================================
43

4+
In this tutorial we will go over the process of creating a custom task,
5+
along with a custom data module.
56

6-
**Detailed guide comming soon!**
7+
.. code:: python
8+
9+
import flash
10+
11+
import torch
12+
from torch.utils.data import TensorDataset, DataLoader
13+
from torch import nn
14+
from sklearn import datasets
15+
from sklearn.model_selection import train_test_split
16+
17+
The Task: Linear regression
18+
---------------------------
19+
20+
Here we create a basic linear regression task by subclassing
21+
``flash.Task``. For the majority of tasks, you will likely only need to
22+
override the ``__init__`` and ``forward`` methods.
23+
24+
.. code:: python
25+
26+
class LinearRegression(flash.Task):
27+
def __init__(self, num_inputs, learning_rate=0.001, metrics=None):
28+
# what kind of model do we want?
29+
model = nn.Linear(num_inputs, 1)
30+
31+
# what loss function do we want?
32+
loss_fn = torch.nn.functional.mse_loss
33+
34+
# what optimizer to do we want?
35+
optimizer = torch.optim.SGD
36+
37+
super().__init__(
38+
model=model,
39+
loss_fn=loss_fn,
40+
optimizer=optimizer,
41+
metrics=metrics,
42+
learning_rate=learning_rate,
43+
)
44+
45+
def forward(self, x):
46+
# we don't actually need to override this method for this example
47+
return self.model(x)
48+
49+
Where is the training step?
50+
~~~~~~~~~~~~~~~~~~~~~~~~~~~
51+
52+
Most models can be trained simply by passing the output of ``forward``
53+
to the supplied ``loss_fn``, and then passing the resulting loss to the
54+
supplied ``optimizer``. If you need a more custom configuration, you can
55+
override ``step`` (which is called for training, validation, and
56+
testing) or override ``training_step``, ``validation_step``, and
57+
``test_step`` individually. These methods behave identically to PyTorch
58+
Lightning’s
59+
`methods <https://pytorch-lightning.readthedocs.io/en/latest/lightning_module.html#methods>`__.
60+
61+
The Data
62+
--------
63+
64+
For a task you will likely need a specific way of loading data. For this
65+
example, lets say we want a ``flash.DataModule`` to be used explicitly
66+
for the prediction of diabetes disease progression. We can create this
67+
``DataModule`` below, wrapping the scikit-learn `Diabetes
68+
dataset <https://scikit-learn.org/stable/datasets/toy_dataset.html#diabetes-dataset>`__.
69+
70+
.. code:: python
71+
72+
class DiabetesPipeline(flash.core.data.TaskDataPipeline):
73+
def after_uncollate(self, samples):
74+
return [f"disease progression: {float(s):.2f}" for s in samples]
75+
76+
class DiabetesData(flash.DataModule):
77+
def __init__(self, batch_size=64, num_workers=0):
78+
x, y = datasets.load_diabetes(return_X_y=True)
79+
x = torch.from_numpy(x).float()
80+
y = torch.from_numpy(y).float().unsqueeze(1)
81+
x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=.20, random_state=0)
82+
83+
train_ds = TensorDataset(x_train, y_train)
84+
test_ds = TensorDataset(x_test, y_test)
85+
86+
super().__init__(
87+
train_ds=train_ds,
88+
test_ds=test_ds,
89+
batch_size=batch_size,
90+
num_workers=num_workers
91+
)
92+
self.num_inputs = x.shape[1]
93+
94+
@staticmethod
95+
def default_pipeline():
96+
return DiabetesPipeline()
97+
98+
You’ll notice we added a ``DataPipeline``, which will be used when we
99+
call ``.predict()`` on our model. In this case we want to nicely format
100+
our ouput from the model with the string ``"disease progression"``, but
101+
you could do any sort of post processing you want (see :ref:`datapipeline`).
102+
103+
Fit
104+
---
105+
106+
Like any Flash Task, we can fit our model using the ``flash.Trainer`` by
107+
supplying the task itself, and the associated data:
108+
109+
.. code:: python
110+
111+
data = DiabetesData()
112+
model = LinearRegression(num_inputs=data.num_inputs)
113+
114+
trainer = flash.Trainer(max_epochs=1000)
115+
trainer.fit(model, data)
116+
117+
With a trained model we can now perform inference. Here we will use a
118+
few examples from the test set of our data:
119+
120+
.. code:: python
121+
122+
predict_data = torch.tensor([
123+
[ 0.0199, 0.0507, 0.1048, 0.0701, -0.0360, -0.0267, -0.0250, -0.0026, 0.0037, 0.0403],
124+
[-0.0128, -0.0446, 0.0606, 0.0529, 0.0480, 0.0294, -0.0176, 0.0343, 0.0702, 0.0072],
125+
[ 0.0381, 0.0507, 0.0089, 0.0425, -0.0428, -0.0210, -0.0397, -0.0026, -0.0181, 0.0072],
126+
[-0.0128, -0.0446, -0.0235, -0.0401, -0.0167, 0.0046, -0.0176, -0.0026, -0.0385, -0.0384],
127+
[-0.0237, -0.0446, 0.0455, 0.0907, -0.0181, -0.0354, 0.0707, -0.0395, -0.0345, -0.0094]])
128+
129+
model.predict(predict_data)
130+
131+
Because of our custom data pipeline’s ``after_uncollate`` method, we
132+
will get a nicely formatted output like the following:
133+
134+
.. code::
135+
136+
['disease progression: 155.90',
137+
'disease progression: 156.59',
138+
'disease progression: 152.69',
139+
'disease progression: 149.05',
140+
'disease progression: 150.90']

docs/source/general/data.rst

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,64 @@
22
Data
33
####
44

5+
.. _datapipeline:
6+
7+
DataPipeline
8+
------------
9+
10+
To make tasks work for inference, one must create a ``DataPipeline``.
11+
The ``flash.core.data.DataPipeline`` exposes 6 hooks to override:
12+
13+
.. code:: python
14+
15+
class DataPipeline:
16+
"""
17+
This class purpose is to facilitate the conversion of raw data to processed or batched data and back.
18+
Several hooks are provided for maximum flexibility.
19+
20+
collate_fn:
21+
- before_collate
22+
- collate
23+
- after_collate
24+
25+
uncollate_fn:
26+
- before_uncollate
27+
- uncollate
28+
- after_uncollate
29+
"""
30+
31+
def before_collate(self, samples: Any) -> Any:
32+
"""Override to apply transformations to samples"""
33+
return samples
34+
35+
def collate(self, samples: Any) -> Any:
36+
"""Override to convert a set of samples to a batch"""
37+
if not isinstance(samples, Tensor):
38+
return default_collate(samples)
39+
return samples
40+
41+
def after_collate(self, batch: Any) -> Any:
42+
"""Override to apply transformations to the batch"""
43+
return batch
44+
45+
def before_uncollate(self, batch: Any) -> Any:
46+
"""Override to apply transformations to the batch"""
47+
return batch
48+
49+
def uncollate(self, batch: Any) -> ny:
50+
"""Override to convert a batch to a set of samples"""
51+
samples = batch
52+
return samples
53+
54+
def after_uncollate(self, samples: Any) -> Any:
55+
"""Override to apply transformations to samples"""
56+
return samplesA
57+
58+
59+
60+
61+
62+
563
Use these utilities to download data.
664

765
-----

0 commit comments

Comments
 (0)