|
| 1 | +<div align="center"> |
| 2 | + |
| 3 | +# Automatic Tensor Parallel (AutoTP) Training of Hugging Face models |
| 4 | + |
| 5 | +</div> |
| 6 | + |
| 7 | + |
| 8 | +# Introduction |
| 9 | + |
| 10 | +Tensor parallelism (TP) is an important memory optimization for training large-scale deep learning models. Despite the popularity of training Hugging Face (HF) [models](https://huggingface.co/models), the model scaling options for **[HF trainer](https://huggingface.co/docs/transformers/main_classes/trainer)** was previously limited to sharded data parallelism through [ZeRO](https://huggingface.co/docs/accelerate/usage_guides/deepspeed)/[FSDP](https://huggingface.co/docs/accelerate/usage_guides/fsdp). While ZeRO3 offers superior memory efficiency, it incurs significant communication costs. ZeRO (1/2) has lower communication overhead, but in the case of very large models, it cannot be used directly due to memory limitations. Therefore, combining TP with ZeRO (1/2) offers more balanced options for memory and performance. Moreover, through TP, we can alleviate the batch scaling limitations imposed by ZeRO/FSDP. |
| 11 | + |
| 12 | +We are pleased to announce that DeepSpeed now provides native automatic tensor parallel training for Hugging Face (HF) transformers. This new feature builds on DeepSpeed's [AutoTP](https://www.deepspeed.ai/tutorials/automatic-tensor-parallelism/) mechanism, which was previously restricted to inference. AutoTP training can be combined with ZeRO to unlock unprecented efficiency benefits for HF model post-training, including: |
| 13 | + |
| 14 | +**1**. Model scaling with lower communication costs than FSDP/ZeRO3 (e.g., use AutoTP + ZeRO1 to achieve ZeRO3 memory savings). |
| 15 | + |
| 16 | +**2**. Batch size scaling for faster training and increased throughput. |
| 17 | + |
| 18 | +**3**. Context length scaling to enable new application scenarios. |
| 19 | + |
| 20 | +We have integrated AutoTP training with ZeRO1 & ZeRO2, with ZeRO3 integration on the way. AutoTP training is available in DeepSpeed versions >= 0.16.4 |
| 21 | + |
| 22 | +# Batch Scaling with AutoTP Training + ZeRO |
| 23 | +The following is a batch scaling experiment of Llama3 8B training conducted on [Gaudi2 Accelerator](https://www.intel.com/content/www/us/en/products/details/processors/ai-accelerators/gaudi.html). |
| 24 | + |
| 25 | + |
| 26 | +<div align="center"> |
| 27 | + |
| 28 | +<img src="media/batchscale.png"> |
| 29 | + |
| 30 | + |
| 31 | +*Figure 1. Batch scaling experiment on Gaudi2, showing throughput performance improvements from 2 to 4 cards by combining AutoTP and ZeRO. The used mbs is the max possible value with the given config. A higher speedup indicates better performance.* |
| 32 | + |
| 33 | +</div> |
| 34 | + |
| 35 | + |
| 36 | + |
| 37 | +<div align="center"> |
| 38 | + |
| 39 | +<img src="media/flowchart.png"> |
| 40 | + |
| 41 | + |
| 42 | +*Figure 2. Model training with AutoTP + ZeRO* |
| 43 | + |
| 44 | +</div> |
| 45 | + |
| 46 | + |
| 47 | +Figure 2 illustrates the basic flowchart, The division of TP and ZeRO is implemented through the AutoTP parser and ZeRO Wrapper in [Accelerate](https://github.com/huggingface/accelerate.git). Besides, The TP-based dataloader and save mechanism are both supported in DeepSpeed and Accelerate. |
| 48 | + |
| 49 | +# Usage |
| 50 | + |
| 51 | +Although we evaluated AutoTP training with Llama2 & Llama3 models in this blog, we expect compatibility with other Hugging Face models, especially [those](https://www.deepspeed.ai/tutorials/automatic-tensor-parallelism/) previously validated with AutoTP inference. Please upgrade accelerate and transformers to the master branch. We will add their minimum version once they have release tag. |
| 52 | + |
| 53 | + |
| 54 | + **Enable TP training** |
| 55 | + |
| 56 | +Similar to ZeRO, AutoTP training is enabled using the [deepspeed configuration file](https://www.deepspeed.ai/docs/config-json/) by specifying ```[tensor_parallel][autotp_size]```. |
| 57 | +``` |
| 58 | + "ZeRO_optimization": { |
| 59 | + "stage": 1, |
| 60 | + "gather_16bit_weights_on_model_save": true, |
| 61 | + ... |
| 62 | + }, |
| 63 | + "tensor_parallel":{ |
| 64 | + "autotp_size": 4 |
| 65 | + }, |
| 66 | +``` |
| 67 | + |
| 68 | +The parallel configuration follows this logic: |
| 69 | + |
| 70 | + |
| 71 | +``` |
| 72 | +tp_size = auto_tp_size |
| 73 | +dp_size = num_gpus / tp_size |
| 74 | +``` |
| 75 | + |
| 76 | +Note that the global_batch_size (gbs) changes with different TP settings: |
| 77 | +``` |
| 78 | +gbs (only dp) = per_device_batch_size * n_gpus * gradient_accumulation_steps |
| 79 | +
|
| 80 | +gbs (dp with tp) = per_device_batch_size * n_gpus / tp_size * gradient_accumulation_steps |
| 81 | +``` |
| 82 | + |
| 83 | + |
| 84 | + |
| 85 | + |
| 86 | + |
| 87 | + |
| 88 | + |
| 89 | + **Save Model** |
| 90 | + |
| 91 | + |
| 92 | + |
| 93 | + |
| 94 | +Saving checkpoints and model files is fully compatible with HF transformers. The [trainer.save_model()](https://huggingface.co/docs/transformers/v4.49.0/en/main_classes/trainer#transformers.Trainer.save_model) method saves the original model. Ensure ```gather_16bit_weights_on_model_save``` is set to ```true```in the [deepspeed configuration file](https://www.deepspeed.ai/docs/config-json/). |
| 95 | +```gather_16bit_weights_on_model_save=true in config. |
| 96 | + "ZeRO_optimization": { |
| 97 | + ... |
| 98 | + "gather_16bit_weights_on_model_save": true, |
| 99 | + }, |
| 100 | +``` |
| 101 | + |
| 102 | +``` |
| 103 | +trainer.save_model(your_saved_path) |
| 104 | +``` |
| 105 | +Models saved this way can be directly used for HF format inference without intermediate transformations. |
| 106 | + |
| 107 | + |
| 108 | + |
| 109 | + **Saving Checkpoints and Resuming** |
| 110 | + |
| 111 | + |
| 112 | + |
| 113 | +Saving Checkpoints remains compatible with HF transformers. Use [trainer.save_state()](https://huggingface.co/docs/transformers/v4.49.0/en/main_classes/trainer#transformers.Trainer.save_state) or set the save interval for automatic saving, which can be used to resume training. |
| 114 | +``` |
| 115 | +trainer.train(resume_from_checkpoint="your_saved_path/checkpoint-1200") |
| 116 | +) |
| 117 | +``` |
| 118 | + |
| 119 | +# Example |
| 120 | +We validated AutoTP training using supervised finetune training (SFT) task: [stanford_alpaca](https://github.com/tatsu-lab/stanford_alpaca). The original benchmark model used in this project is Llama2-7B. |
| 121 | + |
| 122 | + |
| 123 | + |
| 124 | +**Training Loss curve** |
| 125 | + |
| 126 | + |
| 127 | + |
| 128 | +The following loss curves depict SFT training, where gbs is uniformly set to 32, and other configurations match the default experiment settings from ([stanford_alpaca](https://github.com/tatsu-lab/stanford_alpaca)). The loss curves are largely consistent across the following setups: |
| 129 | + |
| 130 | + - ZeRO3 |
| 131 | + - TP + disable ZeRO |
| 132 | + - ZeRO1 and ZeRO1 + AutoTP |
| 133 | + - ZeRO2 and ZeRO2 + AutoTP |
| 134 | + |
| 135 | + |
| 136 | + |
| 137 | + |
| 138 | + |
| 139 | +<div align="center"> |
| 140 | + |
| 141 | + |
| 142 | +<img src="media/zero3.png"> |
| 143 | + |
| 144 | +*Figure 3. Loss curve of ZeRO3 stage training (gbs=32, dp=8)* |
| 145 | + |
| 146 | +</div> |
| 147 | +<div align="center"> |
| 148 | + |
| 149 | +<img src="media/tp8.png"> |
| 150 | + |
| 151 | +*Figure 4. Loss curve of AutoTP training (gbs=32, tp=8)* |
| 152 | +</div> |
| 153 | + |
| 154 | +<div align="center"> |
| 155 | + |
| 156 | +<img src="media/tpzero1.png"> |
| 157 | + |
| 158 | +*Figure 5. Loss curve of AutoTP + ZeRO1 training (gbs=32, dp=2, tp=4)* |
| 159 | +</div> |
| 160 | + |
| 161 | + |
| 162 | +<div align="center"> |
| 163 | + |
| 164 | +<img src="media/tpzero2.png"> |
| 165 | + |
| 166 | +*Figure 6. Loss curve of AutoTP + ZeRO2 training (gbs=32, dp=2, tp=4)* |
| 167 | + |
| 168 | + |
| 169 | +</div> |
| 170 | + |
| 171 | + |
| 172 | + **Resuming Training** |
| 173 | + |
| 174 | + |
| 175 | + We tested recovery training curves from step 1200 in AutoTP + ZeRO1 and AutoTP + ZeRO2, which align with the original training curves. |
| 176 | + |
| 177 | +<div align="center"> |
| 178 | + |
| 179 | +<img src="media/zero1tpload.png"> |
| 180 | + |
| 181 | +*Figure 7. AutoTP + ZeRO1 resuming training* |
| 182 | + |
| 183 | +<img src="media/zero2tpload.png"> |
| 184 | + |
| 185 | +*Figure 8. AutoTP + ZeRO2 resuming training* |
| 186 | + |
| 187 | +</div> |
| 188 | + |
| 189 | + |
| 190 | + |
| 191 | + **Model Evaluation** |
| 192 | + |
| 193 | + |
| 194 | + We conducted inference evaluations for the [MMLU task](https://github.com/EleutherAI/lm-evaluation-harness). |
| 195 | + In MMLU, the scores for AutoTP + ZeRO1 and ZeRO1, as well as AutoTP + ZeRO2 and ZeRO2, are consistent, showing a fixed improvement over the pre-training model before SFT. |
| 196 | + |
| 197 | + |
| 198 | +<div align="center"> |
| 199 | + |
| 200 | + |
| 201 | +| Groups | Version | Filter | n-shot | Metric | Model before SFT | ZeRO1 DP8 training | ZeRO1 TP4 DP2 training | ZeRO2 DP8 training | ZeRO2 TP4DP2 training | |
| 202 | +|--------|---------|--------|--------|--------|-----------------------|--------------------|------------------------|--------------------|------------------------| |
| 203 | +| mmlu | 2 | none | | acc | 0.4185 ± 0.0041 | 0.4472 ± 0.0041 | 0.4444 ± 0.0041 | 0.4543 ± 0.0041 | 0.4529 ± 0.0041 | |
| 204 | +| - humanities | 2 | none | | acc | 0.3979 ± 0.0069 | 0.4185 ± 0.0070 | 0.4145 ± 0.0069 | 0.4274 ± 0.0070 | 0.4272 ± 0.0070 | |
| 205 | +| - other | 2 | none | | acc | 0.4712 ± 0.0089 | 0.5249 ± 0.0087 | 0.5182 ± 0.0088 | 0.5282 ± 0.0087 | 0.5269 ± 0.0087 | |
| 206 | +| - social sciences | 2 | none | | acc | 0.4742 ± 0.0089 | 0.5070 ± 0.0089 | 0.5083 ± 0.0088 | 0.5151 ± 0.0088 | 0.5115 ± 0.0089 | |
| 207 | +| - stem | 2 | none | | acc | 0.3428 ± 0.0084 | 0.3549 ± 0.0084 | 0.3539 ± 0.0084 | 0.3622 ± 0.0084 | 0.3609 ± 0.0084 | |
| 208 | + |
| 209 | +*Table 1. MMLU score with Llama2-7B inference* |
| 210 | + |
| 211 | +</div> |
| 212 | + |
| 213 | + |
| 214 | + |
| 215 | + |
| 216 | + |
| 217 | +# Miscellaneous |
| 218 | + |
| 219 | +If users define their own dataloader, please ensure data consistency within ```deepspeed.utils.get_tensor_model_parallel_group()```. DeepSpeed provides basic validation functions to assist with this. |
| 220 | + |
| 221 | +Furthermore, if users are not using transformers library, you can replace the ```TensorParallel_Layer``` layer and its subclasses as needed. See ```prepare_tp_model``` function in ```unit/model_parallelism/test_autotp_training.py```. Users can also define different shard and gather for subclasses of ```TensorParallel_Layer.``` |
| 222 | + |
| 223 | + |
| 224 | + |
| 225 | + |
| 226 | + |
| 227 | +# Ongoing Work |
| 228 | +- **Optimization**: Communication/Activation optimization. |
| 229 | +- **Usability**: Support [Transformers TP plan](https://github.com/huggingface/transformers/blob/336dc69d63d56f232a183a3e7f52790429b871ef/src/transformers/models/llama/configuration_llama.py#L145), decouple AutoTP parser and more model testing, |
| 230 | + |
| 231 | + |
| 232 | +Theoretically, features supported by ZeRO should also be supported, though extensive testing is pending. |
| 233 | + |
| 234 | +Welcome bug reports, enhancement, and additional model training examples. |
| 235 | + |
| 236 | + |
| 237 | +# Contributors |
| 238 | +This work was made possible through a deep collaboration between Intel and Microsoft. The contributors include Mingzhi Liu, Guokai Ma, Kiefer Kuah, Yejing Lai, Kurt Chen, Yejun Guo, Guangxin Xu, Xiaofei Feng, and Yang Wang from Intel; Guanhua Wang and Olatunji Ruwase from Microsoft. |
0 commit comments