Skip to content

Commit 3d24d5a

Browse files
authored
add checkpoint docs (#140)
* add checkpoint docs * fix formatting in checkpoint docs * fix slack link
1 parent 606cf34 commit 3d24d5a

File tree

11 files changed

+465
-9
lines changed

11 files changed

+465
-9
lines changed

docs/index.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,9 +40,10 @@ Metaflow makes it easy to build and manage real-life data science, AI, and ML pr
4040
- [Computing at Scale](scaling/remote-tasks/introduction)
4141
- [Managing Dependencies](scaling/dependencies)
4242
- [Dealing with Failures](scaling/failures)
43+
- [Checkpointing Progress](scaling/checkpoint/introduction)*New*
4344
- [Loading and Storing Data](scaling/data)
44-
- [Accessing Secrets](scaling/secrets)
4545
- [Organizing Results](scaling/tagging)
46+
- [Accessing Secrets](scaling/secrets)
4647

4748
## III. Deploying to Production
4849

Lines changed: 164 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,164 @@
1+
2+
# Checkpoints in ML/AI libraries
3+
4+
Let's explore how `@checkpoint` works in a real-world scenario when checkpointing training progress with popular ML
5+
libraries.
6+
7+
## Checkpointing XGBoost
8+
9+
Like many other ML libraries, [XGBoost](https://xgboost.readthedocs.io/en/stable/) allows you to define custom callbacks
10+
that are called periodically during training. We can create a custom checkpointer that saves the model to a file, using
11+
`pickle`, [as recommended by XGBoost](https://xgboost.readthedocs.io/en/stable/tutorials/saving_model.html), and calls
12+
`current.checkpoint.save()` to persist it.
13+
14+
Save this snippet in a file, `xgboost_checkpointer.py`:
15+
16+
```python
17+
import os
18+
import pickle
19+
from metaflow import current
20+
import xgboost
21+
22+
class Checkpointer(xgboost.callback.TrainingCallback):
23+
24+
@classmethod
25+
def _path(cls):
26+
return os.path.join(current.checkpoint.directory, 'xgb_cp.pkl')
27+
28+
def __init__(self, interval=10):
29+
self._interval = interval
30+
31+
def after_iteration(self, model, epoch, evals_log):
32+
if epoch > 0 and epoch % self._interval == 0:
33+
with open(self._path(), 'wb') as f:
34+
pickle.dump(model, f)
35+
current.checkpoint.save()
36+
37+
@classmethod
38+
def load(cls):
39+
with open(cls._path(), 'rb') as f:
40+
return pickle.load(f)
41+
```
42+
43+
:::tip
44+
Make sure that the checkpoint directory doesn't accumulate files across invocations, which would make the `save`
45+
operation become slower over time. Either overwrite the same files or clean up the directory between checkpoints.
46+
The `save` call will create a uniquely named checkpoint directory automatically, so you can keep overwriting files
47+
across iterations.
48+
:::
49+
50+
We can then train an XGboost model using `Checkpointer`:
51+
52+
```python
53+
from metaflow import FlowSpec, step, current, Flow,\
54+
Parameter, conda, retry, checkpoint, card, timeout
55+
56+
class CheckpointXGBoost(FlowSpec):
57+
rounds = Parameter("rounds", help="number of boosting rounds", default=128)
58+
59+
@conda(packages={"scikit-learn": "1.6.1"})
60+
@step
61+
def start(self):
62+
from sklearn.datasets import load_breast_cancer
63+
64+
self.X, self.y = load_breast_cancer(return_X_y=True)
65+
self.next(self.train)
66+
67+
@timeout(seconds=15)
68+
@conda(packages={"xgboost": "2.1.4"})
69+
@card
70+
@retry
71+
@checkpoint
72+
@step
73+
def train(self):
74+
import xgboost as xgb
75+
from xgboost_checkpointer import Checkpointer
76+
77+
if current.checkpoint.is_loaded:
78+
cp_model = Checkpointer.load()
79+
cp_rounds = cp_model.num_boosted_rounds()
80+
print(f"Checkpoint was trained for {cp_rounds} rounds")
81+
else:
82+
cp_model = None
83+
cp_rounds = 0
84+
85+
model = xgb.XGBClassifier(
86+
n_estimators=self.rounds - cp_rounds,
87+
eval_metric="logloss",
88+
callbacks=[Checkpointer()])
89+
model.fit(self.X, self.y, eval_set=[(self.X, self.y)], xgb_model=cp_model)
90+
91+
assert model.get_booster().num_boosted_rounds() == self.rounds
92+
print("Training completed!")
93+
self.next(self.end)
94+
95+
@step
96+
def end(self):
97+
pass
98+
99+
if __name__ == "__main__":
100+
CheckpointXGBoost()
101+
```
102+
103+
You can run the flow, saved to `xgboostflow.py`, as usual:
104+
105+
```
106+
python xgboostflow.py --environment=conda run
107+
```
108+
109+
To demonstrate checkpoints in action, [the `@timeout`
110+
decorator](/scaling/failures#timing-out-with-the-timeout-decorator) interrupts training every 15 seconds.
111+
You can adjust the time
112+
depending on how fast the training progresses on your workstation. The `@retry` decorator will then start the task
113+
again, allowing `@checkpoint` to load the latest checkpoint and resume training.
114+
115+
## Checkpointing PyTorch
116+
117+
Using `@checkpoint` with [PyTorch](https://pytorch.org/) is straightforward. Within your training loop, periodically
118+
serialize the model and use `current.checkpoint.save()` to create a checkpoint, along these lines:
119+
120+
```python
121+
model_path = os.path.join(current.checkpoint.directory, 'model')
122+
torch.save(model.state_dict(), model_path)
123+
current.checkpoint.save()
124+
```
125+
126+
Before starting training, check for an available checkpoint and load the model from it if found:
127+
128+
```python
129+
if current.checkpoint.is_loaded:
130+
model.load_state_dict(torch.load(model_path))
131+
```
132+
133+
Take a look at [this reference repository for a complete
134+
example](https://github.com/outerbounds/metaflow-checkpoint-examples/tree/master/mnist_torch_vanilla) showing this pattern in action, in addition to examples for many other frameworks.
135+
136+
## Checkpointing GenAI/LLM fine-tuning
137+
138+
Fine-tuning large language models and other large foundation models for generative AI can easily take hours, running on expensive GPU instances. Take a look at the following examples to learn how `@checkpoint` can be applied to various fine-tuning use cases:
139+
140+
- [Finetuning a LoRA from a model downloaded from
141+
HuggingFace](https://github.com/outerbounds/metaflow-checkpoint-examples/tree/master/lora_huggingface)
142+
143+
- [Finetuning an LLM using LLaMA
144+
Factory](https://github.com/outerbounds/metaflow-checkpoint-examples/tree/master/llama_factory)
145+
146+
- [Finetuning an LLM and serve it with NVIDIA
147+
NIM](https://github.com/outerbounds/metaflow-checkpoint-examples/tree/master/nim_lora)
148+
149+
## Checkpointing distributed workloads
150+
151+
[Metaflow supports distributed training](/scaling/remote-tasks/distributed-computing) and other distributed workloads
152+
which execute across multiple instances in a cluster. When training large models over extended periods across multiple
153+
instances, which greatly increases the likelihood of hitting spurious failures, checkpointing becomes essential to
154+
ensure efficient recovery.
155+
156+
Checkpointing works smoothly when only the control node in a training cluster is designated to handle it, preventing
157+
race conditions that could arise from multiple instances attempting to save progress simultaneously. For reference,
158+
[take a look at this
159+
example](https://github.com/outerbounds/metaflow-checkpoint-examples/tree/master/cifar_distributed) that uses [PyTorch Data Distributed Parallel (DDP)](https://pytorch.org/tutorials/intermediate/ddp_tutorial.html) mode to train a vision model on CIFAR-10 dataset, checkpointing progress with `@checkpoint`.
160+
161+
:::info
162+
Large-scale distributed computing can be challenging. If you need help setting up `@checkpoint` in distributed setups, don’t hesitate to [ask for guidance on Metaflow Slack](http://slack.outerbounds.co).
163+
:::
164+
Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
# Checkpointing Progress
2+
3+
Metaflow artifacts are used to persist models, dataframes, and other Python objects upon task completion. They
4+
checkpoint the flow's state at step boundaries, enabling you to inspect results of a task with
5+
[the Client API](/metaflow/client) and [`resume` execution from any
6+
step](/metaflow/debugging#how-to-use-the-resume-command).
7+
8+
In some cases, a task may require a long time to execute. For example, training a model on an expensive GPU instance
9+
(or across a cluster) may take several hours or even days. In such situations, persisting the final model only upon
10+
task completion is not sufficient. Instead, it is advisable to checkpoint progress periodically while the task is
11+
executing, so you won’t lose hours of work in the event of a failure.
12+
13+
You can use a Metaflow extension, `metaflow-checkpoint`, to create and use in-task checkpoints easily: Just add
14+
`@checkpoint` and call `current.checkpoint.save` to checkpoint progress periodically. A major benefit of `@checkpoint`
15+
is that it keeps checkpoints organized automatically alongside Metaflow tasks, so you don’t have to deal with saving,
16+
loading, organizing, and keeping track of checkpoint files manually.
17+
18+
Notably, `@checkpoint` integrates seamlessly with popular AI and ML frameworks such as XGBoost, PyTorch, and others, as
19+
described below. For more background, read [the announcement blog post for
20+
`@checkpoint`](https://outerbounds.com/blog/indestructible-training-with-checkpoint).
21+
22+
:::info
23+
The `@checkpoint` decorator is not a built-in part of core Metaflow yet, so you have to install it separately as
24+
described below. Also its APIs may change in the future, in contrast to the APIs of core Metaflow which are
25+
guaranteed to stay backwards compatible. Please share your feedback on
26+
[Metaflow Slack](http://slack.outerbounds.co)!
27+
:::
28+
29+
## Installing `@checkpoint`
30+
31+
To use the `@checkpoint` extension, install it with
32+
```
33+
pip install metaflow-checkpoint
34+
```
35+
in the environments where
36+
you develop and deploy Metaflow code. Metaflow packages extensions for remote execution automatically, so you don’t
37+
need to include it in container images used to run tasks remotely.
38+
39+
## Using `@checkpoint`
40+
41+
The `@checkpoint` decorator operates by persisting files in a local directory to the Metaflow datastore. This makes it
42+
directly compatible with many popular ML and AI frameworks that support persisting checkpoints on disk natively.
43+
44+
Let’s demonstrate the functionality with this simple flow that tries to increment a counter in a loop that fails 20% of
45+
the time. Thanks to `@checkpoint` and `@retry`, the `flaky_count` step recovers from exceptions and continues counting
46+
from the latest checkpoint, succeeding eventually:
47+
48+
```python
49+
import os
50+
import random
51+
from metaflow import FlowSpec, current, step, retry, checkpoint
52+
53+
class CheckpointCounterFlow(FlowSpec):
54+
@step
55+
def start(self):
56+
self.counter = 0
57+
self.next(self.flaky_count)
58+
59+
@checkpoint
60+
@retry
61+
@step
62+
def flaky_count(self):
63+
cp_path = os.path.join(current.checkpoint.directory, "counter")
64+
65+
def _save_counter():
66+
print(f"Checkpointing counter value {self.counter}")
67+
with open(cp_path, "w") as f:
68+
f.write(str(self.counter))
69+
self.latest_checkpoint = current.checkpoint.save()
70+
71+
def _load_counter():
72+
if current.checkpoint.is_loaded:
73+
with open(cp_path) as f:
74+
self.counter = int(f.read())
75+
print(f"Checkpoint loaded!")
76+
77+
_load_counter()
78+
print("Counter is now", self.counter)
79+
80+
while self.counter < 10:
81+
self.counter += 1
82+
if self.counter % 2 == 0:
83+
_save_counter()
84+
85+
if random.random() < 0.2:
86+
raise Exception("Bad luck! Try again!")
87+
88+
self.next(self.end)
89+
90+
@step
91+
def end(self):
92+
print("Final counter", self.counter)
93+
94+
if __name__ == "__main__":
95+
CheckpointCounterFlow()
96+
```
97+
98+
After installing the `metaflow-checkpoint` extension, you can run the flow as usual:
99+
```
100+
python checkpoint_counter.py run
101+
```
102+
The flow demonstrates typical usage of `@checkpoint`:
103+
104+
- `@checkpoint` initializes a temporary directory, `current.checkpoint.directory`, which you can use as a staging area for data to be checkpointed.
105+
106+
- By default, `@checkpoint` loads the latest task-specific checkpoint in the directory automatically. If a checkpoint is found, `current.checkpoint.is_loaded` is set to `True`, so you can initialize processing with previously stored data, like loading the latest value of `counter` in this case.
107+
108+
- Periodically during processing, you can save any data required to resume processing in the staging directory and call `current.checkpoint.save()` to persist it in the datastore.
109+
110+
- We save a reference to the latest checkpoint in an artifact, `latest_checkpoint`, which allows us to find and load particular checkpoints later, as explained later in this document.
111+
112+
Behind the scenes, besides loading and storing data efficiently, `@checkpoint` takes care of scoping the checkpoint data to specific tasks. You can use `@checkpoint` in many parallel tasks, even in a foreach, knowing that `@checkpoint` will automatically load checkpoints specific to each branch. It also makes it possible to use checkpoints across runs, as described in [Deciding what checkpoint to use](/scaling/checkpoint/selecting-checkpoints).
113+
114+
## Observing `@checkpoint` through cards
115+
116+
Try running the above flow with [the default Metaflow
117+
`@card`](/metaflow/visualizing-results/effortless-task-inspection-with-default-cards):
118+
```
119+
python checkpoint_counter.py run --with card
120+
```
121+
If a step decorated with `@checkpoint` has a card enabled, it will add information about checkpoints loaded and stored in the card. For instance, the screenshot below shows a card associated with the second attempt (`[Attempt: 1]` at the top of the card) which loaded a checkpoint produced by the first attempt and stored four checkpoints at 2 second intervals:
122+
123+
![](/assets/checkpoint_card.png)

0 commit comments

Comments
 (0)