# Training Qwen3-Next-80B-A3B on 8xH100

## Environment Setup

Environment setup, model download, data preparation, and checkpoint conversion are the same as for the Qwen3-4B model; see [Example: Qwen3-4B](./qwen3-4B.md) and replace the Qwen3-4B parts with Qwen3-Next-80B-A3B-Thinking.

The full procedure for converting the huggingface checkpoint to torch_dist format is as follows. First, download the model weights:

```bash
export BASE_FOLDER=./models/
# Download the model weights (Qwen3-Next-80B-A3B-Thinking)
hf download Qwen/Qwen3-Next-80B-A3B-Thinking --local-dir ${BASE_FOLDER}/Qwen3-Next-80B-A3B-Thinking
```

```shell
cd slime/
pip install -e .

# Optional, for acceleration: flash-linear-attention and causal-conv1d
cd ..  # pick a suitable folder for the clone
git clone https://github.com/fla-org/flash-linear-attention
cd flash-linear-attention
git checkout 9714c595
pip install -e .

wget https://github.com/Dao-AILab/causal-conv1d/releases/download/v1.5.4/causal_conv1d-1.5.4+cu12torch2.8cxx11abiTRUE-cp312-cp312-linux_x86_64.whl
pip install ./causal_conv1d-1.5.4+cu12torch2.8cxx11abiTRUE-cp312-cp312-linux_x86_64.whl
```
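After installation, a quick sanity check confirms the optional acceleration packages import cleanly (a sketch: `fla` and `causal_conv1d` are the import names of the two packages above):

```shell
# Report whether each optional acceleration package imports.
# If one is missing, training still works, just without the acceleration.
for pkg in fla causal_conv1d; do
  if python -c "import ${pkg}" 2>/dev/null; then
    echo "${pkg}: ok"
  else
    echo "${pkg}: missing"
  fi
done
```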

## [Optional] Fix a triton compilation bug on Blackwell (sm100)

See the discussions in https://github.com/triton-lang/triton/issues/8695
and https://github.com/fla-org/flash-linear-attention/issues/638.

To work around the bug, go to the flash-linear-attention folder you just installed and apply the following patch:

```diff
diff --git a/fla/ops/gated_delta_rule/wy_fast.py b/fla/ops/gated_delta_rule/wy_fast.py
index c5119dcf..838f5e4e 100644
--- a/fla/ops/gated_delta_rule/wy_fast.py
+++ b/fla/ops/gated_delta_rule/wy_fast.py
@@ -198,7 +198,14 @@ def prepare_wy_repr_bwd_kernel(
         b_A += tl.dot(b_kb, tl.trans(b_k))
         b_dkb = tl.dot(b_dA, b_k)
         b_db += tl.sum(b_dkb * b_k, 1)
-        b_dk += tl.dot(tl.trans(b_dA), b_kb)
+        b_dk += tl.inline_asm_elementwise(
+            asm="mov.f32 $0, $1;",
+            constraints="=r,r",
+            args=[tl.dot(tl.trans(b_dA), b_kb)],
+            dtype=tl.float32,
+            is_pure=True,
+            pack=1,
+        )
         b_dk += b_dkb * b_b[:, None]
         tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
         tl.store(p_db, b_db.to(p_db.dtype.element_ty), boundary_check=(0,))

```

Save it as `patch.diff` (be sure to keep the trailing empty line in the file!) and run `git apply patch.diff`.

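A safer apply workflow runs `git apply --check` first, which validates the patch without touching any files, so a copy that lost the trailing newline fails loudly instead of half-applying. A self-contained demo of the mechanics, using a throwaway repo rather than flash-linear-attention:

```shell
# Build a tiny repo, generate a patch, then validate-and-apply it.
tmp=$(mktemp -d)
cd "$tmp"
git init -q repo && cd repo
printf 'a\nb\n' > f.txt
git add f.txt && git -c user.email=a@b -c user.name=t commit -qm init
printf 'a\nc\n' > f.txt          # make a change
git diff > patch.diff            # capture it as a patch
git checkout -q -- f.txt         # undo the change
git apply --check patch.diff && git apply patch.diff   # dry-run, then apply
grep -q c f.txt && echo "patch applied"
```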
## Training (Megatron)

**Blackwell is currently not supported.**

Convert the model weights:

```bash
source scripts/models/qwen3-next-80B-A3B.sh
PYTHONPATH=/root/Megatron-LM/ torchrun --nproc-per-node 8 \
    tools/convert_hf_to_torch_dist.py \
    ${MODEL_ARGS[@]} \
    --hf-checkpoint /root/Qwen3-Next-80B-A3B-Thinking/ \
    --save /root/Qwen3-Next-80B-A3B-Thinking_torch_dist/
```
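The `${MODEL_ARGS[@]}` expansion relies on the sourced script defining a bash array of Megatron flags; the array form keeps each flag and each value as a separate argument. A minimal illustration (the flag values here are made up, not the real Qwen3-Next settings):

```shell
# Hypothetical array in the style of scripts/models/qwen3-next-80B-A3B.sh
MODEL_ARGS=(
    --num-layers 48
    --hidden-size 2048
)
# Expands to four separate argv entries, one per line here:
printf '%s\n' "${MODEL_ARGS[@]}"
```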

Single node (8 GPUs):

```bash
cd /root/slime
export BASE_FOLDER=/root
export MASTER_ADDR=127.0.0.1
bash scripts/run-qwen3-next-80B-A3B-8gpus.sh
```
If GPU memory is insufficient, consider disabling `--accumulate-allreduce-grads-in-fp32` and enabling `--grad-reduce-in-bf16`.

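Rough intuition for why that helps: a gradient buffer costs 4 bytes per parameter in fp32 but only 2 in bf16. Very rough arithmetic, assuming the 80B parameters' gradients were evenly sharded across the 8 GPUs (real Megatron sharding with TP/PP/EP differs):

```shell
PARAMS_B=80   # model size in billions of parameters
GPUS=8
echo "fp32 grads per GPU: $((PARAMS_B * 4 / GPUS)) GB"   # 40 GB
echo "bf16 grads per GPU: $((PARAMS_B * 2 / GPUS)) GB"   # 20 GB
```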
Multi-node (4 nodes x 8 GPUs):

```bash
cd /root/slime
export BASE_FOLDER=/root
export MASTER_ADDR=your_master_addr
bash scripts/run-qwen3-next-80B-A3B.sh
```
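For the multi-node run, each of the 4 nodes executes the same script with an identical `MASTER_ADDR` (the rank-0 node's address), giving a total world size of 4 x 8 = 32 GPUs. A hypothetical launch loop over made-up hostnames `node0..node3` (the script is assumed to derive each node's rank from its environment):

```shell
NNODES=4
GPUS_PER_NODE=8
echo "world size: $((NNODES * GPUS_PER_NODE))"   # world size: 32

# Hypothetical ssh fan-out; hostnames and MASTER_ADDR are placeholders.
# for host in node0 node1 node2 node3; do
#   ssh "${host}" "cd /root/slime && export BASE_FOLDER=/root MASTER_ADDR=10.0.0.1 && bash scripts/run-qwen3-next-80B-A3B.sh" &
# done
# wait
```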

## Training (FSDP)

```bash
export BASE_FOLDER=./models/
export MASTER_ADDR=127.0.0.1

bash scripts/run-qwen3-next-80B-A3B-fsdp.sh
```