|
| 1 | +<div align="center"> |
| 2 | + |
| 3 | +# 序列并行:训练极长序列大模型的系统优化 |
| 4 | + |
| 5 | +</div> |
| 6 | + |
| 7 | +XTuner 中的序列并行设计思路参考了 DeepSpeed 的工作 [DeepSpeed Ulysses](https://arxiv.org/abs/2309.14509),并加以优化,以达到直接基于 transformers 算法库或 Huggingface Hub 上的开源模型训练 1M 以上超长序列的目标。 |
| 8 | + |
| 9 | +## 简介 |
| 10 | + |
| 11 | +从生成性AI到科研模型,长序列训练正在变得非常重要。 |
| 12 | + |
| 13 | +在生成性AI领域,会话式AI、长文档摘要、代码库理解和例如 Sora 这种视频生成任务都需要在空间和时间层面对长上下文进行推理。 |
| 14 | + |
| 15 | +对于科学AI来说,长序列同样至关重要,它为更好地理解结构生物学、医疗保健、气候和天气预测以及大分子模拟打开了大门。 |
| 16 | + |
| 17 | +然而,尽管序列长度的重要性不断增长,XTuner 现有的显存优化策略(如 zero 系列),却不足以解决大模型、长序列训练问题。 |
| 18 | + |
| 19 | +同时,受限于通信效率,现有的许多序列并行方法也不够高效。 |
| 20 | + |
| 21 | +另外,现有的序列并行方法普遍存在较多的代码侵入式修改,易用性和维护性都要大打折扣。同时也不满足 XTuner 基于 transformers 算法库或 Huggingface Hub 上的开源模型直接进行训练的要求。 |
| 22 | + |
| 23 | +<div align="center"> |
| 24 | + <img src="https://github.com/InternLM/xtuner/assets/41630003/0b791458-40bd-4dc6-aaf5-ff891fcc112a" width="1000"/> |
| 25 | + <br /><br /> |
| 26 | +</div> |
| 27 | + |
| 28 | +为了解决上述长序列训练带来的问题,XTuner 采用了一种简单、易用且高效的序列并行算法。由于 Transformer 结构较为规整,除 attention 计算外,其他计算过程中 token 之间不会互相影响(即每个 token 的计算是独立的),这一条件为序列并行提供了有利条件。上图展示了序列并行的核心设计。设由 P 个 GPUs 共同计算一个长度为 N 的长序列,在 Attention 计算的第一阶段,长度为 N / P 的子序列会通过线性层投影为 Query、Key、Value。接下来, QKV Tensor 会在参与序列并行计算的多个 GPUs 之间通过高度优化的 all-to-all 通信算子汇聚,得到序列长度为 N ,但更少注意力头的子序列。注意力计算后,通过另一个 all-to-all 通信算子将其转换为长度为 N / P 的子序列,进行后续计算。 |
| 29 | + |
| 30 | +总体而言,XTuner 的序列并行算法具有以下关键特性: |
| 31 | + |
| 32 | +* 支持全量训练**超过百万个token**的序列 |
| 33 | +* 支持百 B 级模型训练:XTuner 的序列并行不仅支持长序列训练,还可结合 zero3 显存优化策略训练大尺寸模型 |
| 34 | +* 完全通用的序列并行 **API 抽象** |
| 35 | + |
| 36 | +## 使用 XTuner 进行序列并行训练 |
| 37 | + |
| 38 | +### Step 1 修改 config 文件 |
| 39 | + |
| 40 | +1. 在 config 中修改 `sequence_parallel_size` 字段即可调整 $sequence\\_parallel\\_world\\_size$ 。 |
| 41 | +2. 同时若想保证与不使用序列并行的训练效果类似,需要同步增大梯度累积的数值为原来的 $sequence\\_parallel\\_world\\_size$ 倍,因为在使用序列并行训练时, $data\\_parallel\\_world\\_size$ 降为了原来的 $\frac{1}{sequence\\_parallel\\_world\\_size}$。 |
| 42 | +3. 替换 DefaultSampler 为支持序列并行的 SequenceParallelSampler。 |
| 43 | + |
| 44 | +**注:需要保证所使用的 GPU 总数可以被 `sequence_parallel_size` 整除。** |
| 45 | + |
| 46 | +```diff |
| 47 | ++ from xtuner.parallel.sequence import SequenceParallelSampler |
| 48 | + |
| 49 | +- sequence_parallel_size = 1 |
| 50 | ++ sequence_parallel_size = 4 # take `sequence_parallel_size = 4`` as an example |
| 51 | + |
| 52 | +- accumulative_counts = 1 |
| 53 | ++ accumulative_counts = 4 # accumulative_counts = accumulative_counts * sequence_parallel_size |
| 54 | + |
| 55 | +####################################################################### |
| 56 | +# PART 3 Dataset & Dataloader # |
| 57 | +####################################################################### |
| 58 | +train_dataloader = dict( |
| 59 | +- sampler=dict(type=DefaultSampler, shuffle=True), |
| 60 | ++ sampler=dict(type=SequenceParallelSampler, seed=1024, shuffle=True), |
| 61 | + ...) |
| 62 | +``` |
| 63 | + |
| 64 | +另外,若需要进一步拓展模型的长文本处理能力,需要进一步修改 config 中的 `max_position_embeddings` 字段。例如需要将模型的上下文长度拓展为 64K 时,可进行如下修改: |
| 65 | + |
| 66 | +```diff |
| 67 | ++ max_position_embeddings = 65536 |
| 68 | + |
| 69 | +####################################################################### |
| 70 | +# PART 2 Model & Tokenizer # |
| 71 | +####################################################################### |
| 72 | +model = dict( |
| 73 | + type=SupervisedFinetune, |
| 74 | ++ max_position_embeddings = max_position_embeddings, |
| 75 | + ...) |
| 76 | +``` |
| 77 | + |
| 78 | +### Step 2 开始训练 |
| 79 | + |
| 80 | +需要使用 DeepSpeed 进行训练: |
| 81 | + |
| 82 | +```bash |
| 83 | +(DIST) NPROC_PER_NODE=${GPU_NUM} xtuner train ${CONFIG_PATH} --deepspeed deepspeed_zero2 |
| 84 | +(SLURM) srun ${SRUN_ARGS} xtuner train ${CONFIG_PATH} --launcher slurm --deepspeed deepspeed_zero2 |
| 85 | +``` |
| 86 | + |
| 87 | +- ${CONFIG_PATH} 为 Step 1 中修改得到的 config 文件路径 |
| 88 | +- 可根据实际情况选择使用不同的 zero 策略 |
| 89 | + |
| 90 | +## 序列并行 API 抽象 |
| 91 | + |
| 92 | +为了提升算法的可迁移性,XTuner 中抽象出了序列并行所必须的五个 API 接口: |
| 93 | +- 序列并行分布式环境初始化 (init_sequence_parallel) |
| 94 | +- 适配序列并行的 Data Sampler (SequenceParallelSampler) |
| 95 | +- 数据 Pad 与切分 (pad_for_sequence_parallel, split_for_sequence_parallel) |
| 96 | +- 适配序列并行的 Attention (dispatch_modules) |
| 97 | +- reduce loss 以正确打印训练损失 (reduce_sequence_parallel_loss) |
| 98 | + |
| 99 | +### 序列并行分布式环境初始化 |
| 100 | + |
| 101 | +由于序列并行算法会将长序列切分为 $sequence\\_parallel\\_world\\_size$ 块,并将每个子序列分发给对应的 GPU 独立进行计算。因此需要在训练开始前初始化序列并行分布式环境,以指定哪几块 GPU 共同负责一个长序列输入的计算。 |
| 102 | + |
| 103 | +一个 $sequence\\_parallel\\_world\\_size = 4$ 的示例如下: |
| 104 | + |
| 105 | +```python |
| 106 | +# We have to initialize the distributed training environment first. |
| 107 | +# Here is an example when training on slurm scheduler |
| 108 | +# from xtuner.parallel.sequence import init_dist |
| 109 | +# init_dist('slurm', 'nccl', init_backend='deepspeed') |
| 110 | +from xtuner.parallel.sequence import init_sequence_parallel |
| 111 | +sequence_parallel_world_size = 4 |
| 112 | +init_sequence_parallel(sequence_parallel_world_size) |
| 113 | +``` |
| 114 | + |
| 115 | +上述过程在 xtuner/engine/_strategy/deepspeed.py 中实现。 |
| 116 | + |
| 117 | +### Data Sampler 适配序列并行 |
| 118 | + |
| 119 | +在使用序列并行后,Dataloader 的采样策略需要进一步调整。例如当 $sequence\\_parallel\\_world\\_size = 4$ 时,4 块 GPU 从 Dataloader 拿到的数据需要是完全一样的。 |
| 120 | + |
| 121 | +在构建 Dataloader 时搭配 XTuner 中提供的 SequenceParallelSampler 使用即可: |
| 122 | + |
| 123 | +```python |
| 124 | +from xtuner.parallel.sequence import SequenceParallelSampler |
| 125 | +dataloader = DataLoader( |
| 126 | + train_dataset, sampler=SequenceParallelSampler(train_dataset), |
| 127 | + **other_dataloader_params) |
| 128 | +``` |
| 129 | + |
| 130 | +### 数据 Pad 与切分 |
| 131 | + |
| 132 | +由于每条训练数据的长度可能不尽相同,我们需要将数据进行 Pad 以使得序列长度可以被 $sequence\\_parallel\\_world\\_size$ 整除,这样一条长数据才能被均等地分发给不同的 GPU 上。 |
| 133 | + |
| 134 | +训练过程中需要被 Pad 的 Tensor 往往有 input_ids, labels, position_ids, attention_mask 四个,pad 的过程可以通过以下方式实现: |
| 135 | + |
| 136 | +```python |
| 137 | +from xtuner.parallel.sequence import pad_for_sequence_parallel |
| 138 | +input_ids, labels, position_ids, attention_mask = pad_for_sequence_parallel( |
| 139 | + input_ids, labels, position_ids, attention_mask) |
| 140 | +``` |
| 141 | + |
| 142 | +如果训练过程用不到 attention_mask,那么可以: |
| 143 | + |
| 144 | +```python |
| 145 | +input_ids, labels, position_ids, _ = pad_for_sequence_parallel( |
| 146 | + input_ids, labels, position_ids) |
| 147 | +``` |
| 148 | + |
| 149 | +Pad 后,我们需要对长序列均等切分: |
| 150 | + |
| 151 | +```python |
| 152 | +from xtuner.parallel.sequence import split_for_sequence_parallel |
| 153 | +# attention mask should not be split |
| 154 | +input_ids, labels, position_ids = split_for_sequence_parallel( |
| 155 | + input_ids, labels, position_ids) |
| 156 | +``` |
| 157 | + |
| 158 | +以上两步在 xtuner/dataset/collate_fns/defalut_collate_fn.py 中实现。 |
| 159 | + |
| 160 | +### Attention 适配序列并行 |
| 161 | + |
| 162 | +在 Attention 的计算过程中,序列中的不同 token 是不能独立运算的,但不同的 attention head 之间的计算却是独立的。因此,如[第一节](#简介)所述,需要在计算 Attention 前后(即 qkv_proj 后和 o_proj 前)分别插入一个 *all-to-all* 操作。 |
| 163 | + |
| 164 | +XTuner 提供了 dispatch_modules 接口以支持修改模型 Attention 的计算方式: |
| 165 | + |
| 166 | +```python |
| 167 | +from xtuner.model.modules import dispatch_modules |
| 168 | +model: AutoModelForCausalLM |
| 169 | +dispatch_modules(model) |
| 170 | +``` |
| 171 | + |
| 172 | +上述过程在 xtuner/model/sft.py 中实现。 |
| 173 | + |
| 174 | +### Reduce Loss 以正确打印训练损失 |
| 175 | + |
| 176 | +这个 API 对于保证训练的正确性不是必须的,但对于观测模型训练状态,打印训练 loss 是非常有用的。 |
| 177 | + |
| 178 | +```python |
| 179 | +from xtuner.parallel.sequence import reduce_sequence_parallel_loss |
| 180 | +outputs = llm(input_ids=input_ids, labels=labels, **kwargs) |
| 181 | +num_tokens_per_rank = (labels != -100).sum() |
| 182 | +# Suppose sequence parallel world size equals to 4, |
| 183 | +# losses on rank0, rank1, rank2, rank3 are different. |
| 184 | +loss = reduce_sequence_parallel_loss(outputs.loss, num_tokens_per_rank) |
| 185 | +# After loss reduction, losses on rank0, rank1, rank2, rank3 are the same. |
| 186 | +``` |
| 187 | + |
| 188 | +上述过程在 xtuner/model/sft.py 中实现。 |
0 commit comments