Commit 1633285

OleehyO and zhangch9 authored
[refactor] Changes deepspeed + accelerate to FSDP (#33)
* [feat] Add DistributedLogger
* [sampler] Add distributed packing sampler
* [logger] Use filelock to prevent write conflicts
* [trainer] Refactor for FSDP training
* [tool] Add merge tool for dist checkpoint
* [deps] Update dependencies
* [docs] Update

---------

Co-authored-by: Chenhui Zhang <[email protected]>
1 parent 864f143 commit 1633285
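The `[logger]` item refers to the common pattern of serializing log writes from multiple ranks through a file lock. A minimal sketch of that pattern using the `filelock` package (an illustration only, not the commit's actual `DistributedLogger`):

```python
from filelock import FileLock

LOG_PATH = "train.log"  # hypothetical shared log file

def append_log(line: str) -> None:
    # A sidecar .lock file serializes appends from concurrent processes,
    # so interleaved writes from different ranks cannot corrupt the log.
    with FileLock(LOG_PATH + ".lock"):
        with open(LOG_PATH, "a", encoding="utf-8") as f:
            f.write(line + "\n")
```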

63 files changed: +1529 / -1969 lines

docs/01-Intro.md

Lines changed: 2 additions & 5 deletions
@@ -4,15 +4,12 @@ slug: /
 
 # Introduction
 
-CogKit is an open-source project that provides a user-friendly interface for researchers and developers to utilize ZhipuAI's [**CogView**](https://huggingface.co/collections/THUDM/cogview-67ac3f241eefad2af015669b) (image generation) and [**CogVideoX**](https://huggingface.co/collections/THUDM/cogvideo-66c08e62f1685a3ade464cce) (video generation) models. It streamlines multimodal tasks such as **text-to-image (T2I)**, **text-to-video (T2V)**, and **image-to-video (I2V)**. Users must comply with legal and ethical guidelines to ensure responsible implementation.
+CogKit is an open-source project that provides a user-friendly interface for researchers and developers to utilize ZhipuAI's [CogView](https://huggingface.co/collections/THUDM/cogview-67ac3f241eefad2af015669b) (image generation) and [CogVideoX](https://huggingface.co/collections/THUDM/cogvideo-66c08e62f1685a3ade464cce) (video generation) models. It streamlines multimodal tasks such as text-to-image(T2I), text-to-video(T2V), and image-to-video(I2V). Users must comply with legal and ethical guidelines to ensure responsible implementation.
 
 ## Supported Models
 
 Please refer to the [Model Card](./05-Model%20Card.mdx) for more details.
 
 ## Environment Testing
 
-This repository has been tested in environments with `1×A100` and `8×A100` GPUs, using `CUDA 12.4, Python 3.10.16`.
-
-- Cog series models typically do not support `FP16` precision (Only `CogVideoX-2B` support); GPUs like the `V100` cannot be fine-tuned properly (Will cause `loss=nan` for example). At a minimum, an `A100` or other GPUs supporting `BF16` precision should be used.
-- We have not yet systematically tested the minimum GPU memory requirements for each model. For `LORA(bs=1 with offload)`, a single `A100` GPU is sufficient. For `SFT`, our tests have passed in an `8×A100` environment.
+This repository has been tested in environments with 8×A100 GPUs, using CUDA 12.4, Python 3.10.16.
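The deleted note points out that Cog models need BF16-capable GPUs (FP16 fine-tuning leads to `loss=nan`); a quick way to check a machine for native BF16 support (generic PyTorch, not part of CogKit):

```python
import torch

assert torch.cuda.is_available(), "No CUDA device found"
major, _ = torch.cuda.get_device_capability()
print("GPU:", torch.cuda.get_device_name())
# Native BF16 needs compute capability >= 8.0 (Ampere, e.g. A100); V100 is 7.0.
print("Native BF16 support:", major >= 8)
```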

docs/04-Finetune/01-Prerequisites.mdx

Lines changed: 34 additions & 84 deletions
@@ -3,109 +3,69 @@
 
 # Prerequisites
 
-Before starting fine-tuning, please ensure your machine meets the minimum hardware requirements listed in the tables below. The tables show the minimum VRAM (GPU memory) requirements for different models under various configurations.
+Before starting fine-tuning, please ensure your machine meets the minimum hardware requirements listed in the tables below. The tables show the minimum VRAM requirements for different models under various configurations (test on 8xA100).
 
 ## CogVideo Series
 
 <table style={{ textAlign: "center" }}>
 <thead>
 <tr>
 <th style={{ textAlign: "center" }}>Model</th>
-<th style={{ textAlign: "center" }}>Training Type</th>
-<th style={{ textAlign: "center" }}>Distribution Strategy</th>
-<th style={{ textAlign: "center" }}>Training Resolution (FxHxW)</th>
+<th style={{ textAlign: "center" }}>Type</th>
+<th style={{ textAlign: "center" }}>Strategy</th>
+<th style={{ textAlign: "center" }}>Resolution <br /> (FxHxW)</th>
 <th style={{ textAlign: "center" }}>Requirement</th>
 </tr>
 </thead>
 <tbody>
 <tr>
-<td rowspan="6">cogvideox-t2v-2b</td>
+<td rowSpan="2">cogvideox-t2v-2b</td>
 <td>lora</td>
 <td>DDP</td>
 <td>49x480x720</td>
-<td>16GB VRAM</td>
+<td>1 GPU with <br /> 12GB VRAM</td>
 </tr>
 <tr>
-<td rowspan="5">sft</td>
+<td rowSpan="1">sft</td>
 <td>DDP</td>
 <td>49x480x720</td>
-<td>36GB VRAM</td>
+<td>1 GPU with <br /> 25GB VRAM</td>
 </tr>
 <tr>
-<td>1-GPU zero-2 + opt offload</td>
-<td>49x480x720</td>
-<td>17GB VRAM</td>
-</tr>
-<tr>
-<td>8-GPU zero-2</td>
-<td>49x480x720</td>
-<td>17GB VRAM</td>
-</tr>
-<tr>
-<td>8-GPU zero-3</td>
-<td>49x480x720</td>
-<td>19GB VRAM</td>
-</tr>
-<tr>
-<td>8-GPU zero-3 + opt and param offload</td>
-<td>49x480x720</td>
-<td>14GB VRAM</td>
-</tr>
-<tr>
-<td rowspan="5">cogvideox-\{t2v,i2v\}-5b</td>
+<td rowSpan="3">cogvideox-\{t2v,i2v\}-5b</td>
 <td>lora</td>
 <td>DDP</td>
 <td>49x480x720</td>
-<td>24GB VRAM</td>
-</tr>
-<tr>
-<td rowspan="4">sft</td>
-<td>1-GPU zero-2 + opt offload</td>
-<td>49x480x720</td>
-<td>42GB VRAM</td>
+<td>1 GPU with <br /> 24GB VRAM</td>
 </tr>
 <tr>
-<td>8-GPU zero-2</td>
+<td rowSpan="2">sft</td>
+<td>FSDP fullshard</td>
 <td>49x480x720</td>
-<td>42GB VRAM</td>
+<td>8 GPU with <br /> 20GB VRAM</td>
 </tr>
 <tr>
-<td>8-GPU zero-3</td>
+<td>FSDP fullshard + offload</td>
 <td>49x480x720</td>
-<td>43GB VRAM</td>
+<td>1 GPU with <br /> 16GB VRAM</td>
 </tr>
 <tr>
-<td>8-GPU zero-3 + opt and param offload</td>
-<td>49x480x720</td>
-<td>28GB VRAM</td>
-</tr>
-<tr>
-<td rowspan="5">cogvideox1.5-\{t2v,i2v\}-5b</td>
+<td rowSpan="3">cogvideox1.5-\{t2v,i2v\}-5b</td>
 <td>lora</td>
 <td>DDP</td>
 <td>81x768x1360</td>
-<td>35GB VRAM</td>
-</tr>
-<tr>
-<td rowspan="4">sft</td>
-<td>1-GPU zero-2 + opt offload</td>
-<td>81x768x1360</td>
-<td>56GB VRAM</td>
+<td>1 GPU with <br /> 32GB VRAM</td>
 </tr>
 <tr>
-<td>8-GPU zero-2</td>
+<td rowSpan="2">sft</td>
+<td>FSDP fullshard</td>
 <td>81x768x1360</td>
-<td>55GB VRAM</td>
+<td>8 GPUs with <br /> 31GB VRAM</td>
 </tr>
 <tr>
-<td>8-GPU zero-3</td>
+<td>FSDP fullshard + offload</td>
 <td>81x768x1360</td>
-<td>55GB VRAM</td>
-</tr>
-<tr>
-<td>8-GPU zero-3 + opt and param offload</td>
-<td>81x768x1360</td>
-<td>40GB VRAM</td>
+<td>8 GPUs with <br /> 27GB VRAM</td>
 </tr>
 </tbody>
 </table>
@@ -116,46 +76,36 @@ Before starting fine-tuning, please ensure your machine meets the minimum hardwa
 <thead>
 <tr>
 <th style={{ textAlign: "center" }}>Model</th>
-<th style={{ textAlign: "center" }}>Training Type</th>
-<th style={{ textAlign: "center" }}>Distribution Strategy</th>
-<th style={{ textAlign: "center" }}>Training Resolution (HxW)</th>
+<th style={{ textAlign: "center" }}>Type</th>
+<th style={{ textAlign: "center" }}>Strategy</th>
+<th style={{ textAlign: "center" }}>Resolution <br /> (HxW)</th>
 <th style={{ textAlign: "center" }}>Requirement</th>
 </tr>
 </thead>
 <tbody>
 <tr>
-<td rowspan="6">CogView4-6B</td>
-<td>qlora + param offload <br />(`--low_vram`)</td>
+<td rowSpan="4">CogView4-6B</td>
+<td>qlora + offload <br />(enable --low_vram)</td>
 <td>DDP</td>
 <td>1024x1024</td>
-<td>9GB VRAM</td>
+<td>1 GPU with <br /> 9GB VRAM</td>
 </tr>
 <tr>
 <td>lora</td>
 <td>DDP</td>
 <td>1024x1024</td>
-<td>30GB VRAM</td>
-</tr>
-<tr>
-<td rowspan="4">sft</td>
-<td>1-GPU zero-2 + opt offload</td>
-<td>1024x1024</td>
-<td>42GB VRAM</td>
-</tr>
-<tr>
-<td>8-GPU zero-2</td>
-<td>1024x1024</td>
-<td>50GB VRAM</td>
+<td>1 GPU with <br /> 20GB VRAM</td>
 </tr>
 <tr>
-<td>8-GPU zero-3</td>
+<td rowSpan="2">sft</td>
+<td>FSDP fullshard</td>
 <td>1024x1024</td>
-<td>47GB VRAM</td>
+<td>8 GPUs with <br /> 28GB VRAM</td>
 </tr>
 <tr>
-<td>8-GPU zero-3 + opt and param offload</td>
+<td>FSDP fullshard + offload</td>
 <td>1024x1024</td>
-<td>28GB VRAM</td>
+<td>8 GPUs with <br /> 22GB VRAM</td>
 </tr>
 </tbody>
 </table>
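In the updated tables, "FSDP fullshard" corresponds to PyTorch FSDP's FULL_SHARD sharding strategy and "+ offload" to CPU parameter offload. A minimal sketch of that wrapping in plain PyTorch (an illustration, not CogKit's actual trainer code; the model below is a toy stand-in for the diffusion transformer):

```python
import torch
import torch.distributed as dist
from torch.distributed.fsdp import (
    CPUOffload,
    FullyShardedDataParallel as FSDP,
    MixedPrecision,
    ShardingStrategy,
)

# Launch with e.g.: torchrun --nproc_per_node=8 this_script.py
dist.init_process_group("nccl")
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

model = torch.nn.Linear(4096, 4096)  # toy stand-in for the real model

fsdp_model = FSDP(
    model,
    sharding_strategy=ShardingStrategy.FULL_SHARD,  # shard params, grads, optimizer state
    cpu_offload=CPUOffload(offload_params=True),    # omit this line for the non-offload rows
    mixed_precision=MixedPrecision(param_dtype=torch.bfloat16,
                                   reduce_dtype=torch.bfloat16),
    device_id=torch.cuda.current_device(),
)
```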

docs/04-Finetune/02-Quick Start.md

Lines changed: 14 additions & 15 deletions
@@ -27,36 +27,35 @@ We recommend that you read the corresponding [model card](../05-Model%20Card.mdx
 :::
 
 1. Navigate to the `CogKit/` directory after cloning the repository
+
 ```bash
 cd CogKit/
 ```
 
-2. Choose the appropriate training script from the `quickstart/scripts` directory based on your task type and distribution strategy. For example, `train_ddp_t2i.sh` corresponds to DDP strategy + text-to-image task
-
-3. Review and adjust the parameters in the selected training script (e.g., `--data_root`, `--output_dir`, etc.)
+2. Choose the appropriate subdirectory from the `quickstart/scripts` based on your task type and distribution strategy. For example, `t2i` corresponds to text-to-image task
 
-4. [Optional] If you are using ZeRO strategy, refer to `quickstart/configs/accelerate_config.yaml` to confirm your ZeRO config file and number of GPUs.
+3. Review and adjust the parameters in `config.yaml` in the selected training directory
 
-5. Run the script, for example:
+4. Run the script in the selected directory:
 
 ```bash
-cd quickstart/scripts
-bash train_ddp_t2i.sh
+bash start_train.sh
 ```
 
 ## Load Fine-tuned Model
 
-### LoRA
-
-After fine-tuning with LoRA, you can load your trained weights during inference using the `--lora_model_id_or_path` option or parameter. For more details, please refer to the inference guide.
+### Merge Checkpoint
 
-### ZeRO
-
-After fine-tuning with ZeRO strategy, you need to use the `zero_to_fp32.py` script provided in the `quickstart/tools/converters` directory to convert the ZeRO checkpoint weights into Diffusers format. For example:
+After fine-tuning, you need to use the `merge.py` script to merge the distributed checkpoint weights into a single checkpoint (**except for QLoRA fine-tuning**).
+The script can be found in the `quickstart/tools/converters` directory.
+For example:
 
 ```bash
 cd quickstart/tools/converters
-python zero2diffusers.py checkpoint_dir/ output_dir/ --bfloat16
+python merge.py --checkpoint_dir ckpt/ --output_dir output_dir/
+# Add --lora option if you are using LoRA fine-tuning
 ```
 
-During inference, pass the `output_dir/` to the `--transformer_path` option or parameter. For more details, please refer to the inference guide.
+### Load Checkpoint
+
+You can pass the `output_dir` to the `--lora_model_id_or_path` option if you are using LoRA fine-tuning, or to the `--transformer_path` option if you are using FSDP fine-tuning. For more details, please refer to the inference guide.
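For context, a rough illustration of what those two options map to when loading the fine-tuned weights with diffusers. The base model ID, the local paths, and the assumption that `merge.py` emits diffusers-loadable weights are all assumptions here; the inference guide is authoritative:

```python
import torch
from diffusers import CogVideoXPipeline, CogVideoXTransformer3DModel

# FSDP/SFT case: load the merged transformer produced by merge.py (--transformer_path)
transformer = CogVideoXTransformer3DModel.from_pretrained(
    "output_dir/", torch_dtype=torch.bfloat16
)
pipe = CogVideoXPipeline.from_pretrained(
    "THUDM/CogVideoX-5b", transformer=transformer, torch_dtype=torch.bfloat16
)

# LoRA case: keep the base transformer and attach the merged LoRA weights instead
# (--lora_model_id_or_path)
# pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
# pipe.load_lora_weights("output_dir/")
```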

pyproject.toml

Lines changed: 2 additions & 3 deletions
@@ -18,7 +18,6 @@ dependencies = [
 "pydantic~=2.10",
 "sentencepiece==0.2.0",
 "transformers~=4.49",
-"wandb~=0.19.8",
 "fastapi[standard]~=0.115.11",
 "fastapi_cli~=0.0.7",
 "openai~=1.67",
@@ -31,10 +30,10 @@ dependencies = [
 [project.optional-dependencies]
 finetune = [
 "datasets~=3.4",
-"deepspeed~=0.16.4",
+"wandb~=0.19.8",
 "av~=14.2.0",
 "bitsandbytes~=0.45.4",
-"tensorboard~=2.19",
+"pyyaml>=6.0.2",
 ]
 
 [project.urls]
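With this change `wandb` moves into the optional `finetune` extra (alongside the new `pyyaml` pin), so a training install would look roughly like this, assuming an editable install from the repository root:

```bash
# From the CogKit/ repository root; the [finetune] extra pulls in datasets, wandb,
# av, bitsandbytes and pyyaml on top of the base dependencies.
pip install -e ".[finetune]"
```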

quickstart/configs/accelerate_config.yaml

Lines changed: 0 additions & 26 deletions
This file was deleted.

quickstart/configs/zero/zero2.yaml

Lines changed: 0 additions & 38 deletions
This file was deleted.

quickstart/configs/zero/zero2_offload.yaml

Lines changed: 0 additions & 42 deletions
This file was deleted.
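These ZeRO configs and the old zero_to_fp32 conversion path are superseded by the FSDP flow, where the new `merge.py` collapses a sharded training checkpoint into a single file. As a rough sketch of that kind of merge, assuming checkpoints written with `torch.distributed.checkpoint` (this uses the generic PyTorch utility, not CogKit's actual script):

```python
from torch.distributed.checkpoint.format_utils import dcp_to_torch_save

# Collapse a sharded torch.distributed.checkpoint directory into one torch.save file.
dcp_to_torch_save("ckpt/", "merged_checkpoint.pt")
```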
