
Commit 4ec555b

Restore single-node instructions to run GRPO (#549)
1 parent 8000dd2 commit 4ec555b

File tree

1 file changed (+32 -8 lines)


README.md

@@ -160,14 +160,31 @@ ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_con
 ```
 
 ### GRPO
-We use TRL's new distributed vLLM server and GRPOTraining in order to scale to larger >7B models. We provide an example slurm script:
+
+We use TRL's [vLLM backend](https://huggingface.co/docs/trl/speeding_up_training?vllm+examples=GRPO#vllm-for-fast-generation-in-online-methods) to scale training to large models across multiple nodes. For single-node training of smol models across 8 GPUs, first spin up the vLLM server to run on e.g. 1 GPU as follows:
+
 ```shell
-sbatch --job-name=trl-Qwen2.5-Math-7B-config_simple_rl --nodes=2 slurm/train.slurm Qwen2.5-Math-7B grpo config_simple_rl zero3
+CUDA_VISIBLE_DEVICES=0 trl vllm-serve --model deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B
 ```
 
-You will need to adapt the `slurm/train.slurm` script to match your cluster.
+Once the server is up, run training on the remaining GPUs as follows:
+
+```shell
+CUDA_VISIBLE_DEVICES=1,2,3,4,5,6,7 ACCELERATE_LOG_LEVEL=info \
+    accelerate launch --config_file recipes/accelerate_configs/zero2.yaml --num_processes 7 \
+    src/open_r1/grpo.py --config recipes/DeepSeek-R1-Distill-Qwen-1.5B/grpo/config_demo.yaml
+```
+
+> [!WARNING]
+> The chat template used in the distilled DeepSeek models omits the contents of the reasoning block within the `<think>` and `</think>` tags. It also prefills the assistant response with `<think>`, which interferes with the format reward function. To handle that, it is important to override the chat template as done in e.g. [recipes/DeepSeek-R1-Distill-Qwen-1.5B/grpo/config_demo.yaml](./recipes/DeepSeek-R1-Distill-Qwen-1.5B/grpo/config_demo.yaml).
+
+For multi-node training, we provide an example Slurm script:
 
-Our final [model](https://huggingface.co/Dongwei/Qwen-2.5-7B_Base_Math_smalllr), while using different learning rates, loss functions and reward structures, achieves 69.4% accuracy on MATH-500, demonstrating a 17%+ improvement over the base model.
+```shell
+sbatch --nodes=2 slurm/train.slurm Qwen2.5-Math-7B grpo config_simple_rl zero3
+```
+
+You will need to adapt the `slurm/train.slurm` script to match your cluster.
 
 #### 👨‍💻 Training with a code interpreter

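The single-node recipe in the hunk above says to launch training "once the server is up". A minimal readiness-check sketch follows, under two assumptions not stated in the diff: the shell is bash (the `/dev/tcp` redirection is a bash-ism, not POSIX), and `trl vllm-serve` is listening on `localhost:8000` (adjust the host/port to whatever your server actually binds):

```shell
# Sketch: block until a TCP port accepts connections, then it is safe
# to launch training. Requires bash for the /dev/tcp redirection.
wait_for_port() {
  local host=$1 port=$2 tries=${3:-60}
  local i
  for ((i = 0; i < tries; i++)); do
    # Subshell attempt to open a TCP connection; fd closes when it exits.
    if (exec 3<>"/dev/tcp/$host/$port") 2>/dev/null; then
      return 0  # port is open
    fi
    sleep 1
  done
  return 1  # gave up after $tries attempts
}

# Hypothetical usage, assuming the default port:
# wait_for_port localhost 8000 && CUDA_VISIBLE_DEVICES=1,2,3,4,5,6,7 accelerate launch ...
```

This avoids the race where `accelerate launch` starts before the vLLM server finishes loading weights.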
@@ -198,12 +215,18 @@ Then make sure your dataset contains a `verification_info` column with the follo
 }
 ```
 
-For example, to train a smol model on Python problems, run:
+For example, to train a smol model on Python problems, start the vLLM server:
 
 ```shell
-ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/zero2.yaml \
-    --num_processes=7 src/open_r1/grpo.py \
-    --config recipes/Qwen2.5-1.5B-Instruct/grpo/config_demo_code.yaml
+CUDA_VISIBLE_DEVICES=0 trl vllm-serve --model Qwen/Qwen2.5-1.5B-Instruct
+```
+
+Then run training with:
+
+```shell
+CUDA_VISIBLE_DEVICES=1,2,3,4,5,6,7 ACCELERATE_LOG_LEVEL=info \
+    accelerate launch --config_file recipes/accelerate_configs/zero2.yaml --num_processes=7 \
+    src/open_r1/grpo.py --config recipes/Qwen2.5-1.5B-Instruct/grpo/config_demo_code.yaml
 ```
 
 #### IOI problems
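Both recipes in this commit use the same split: GPU 0 serves vLLM, the remaining GPUs train, and `--num_processes` equals the number of training GPUs. On a machine with a different GPU count, the two values can be derived rather than hard-coded; a small sketch (the `8` below is an assumption matching the commands in the diff):

```shell
# Sketch: derive the training GPU list and process count when GPU 0
# is reserved for the vLLM server. Assumes TOTAL_GPUS >= 2.
TOTAL_GPUS=8                                   # e.g. from: nvidia-smi -L | wc -l
NUM_PROCESSES=$((TOTAL_GPUS - 1))              # every GPU except GPU 0
TRAIN_GPUS=$(seq 1 $((TOTAL_GPUS - 1)) | paste -sd, -)

echo "$TRAIN_GPUS"      # 1,2,3,4,5,6,7
echo "$NUM_PROCESSES"   # 7
```

These values can then be spliced into the launch command as `CUDA_VISIBLE_DEVICES=$TRAIN_GPUS` and `--num_processes $NUM_PROCESSES`.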
@@ -214,6 +237,7 @@ To get piston workers running, see [slurm/piston/README.md](./slurm/piston/READM
 Set your environment variable `PISTON_ENDPOINTS` to `slurm` or to a list of piston worker endpoints.
 
 See the [example recipe](./recipes/Qwen2.5-1.5B-Instruct/grpo/config_demo_code_ioi.yaml) for how to use the reward function:
+
 ```shell
 ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/zero2.yaml \
     --num_processes=7 src/open_r1/grpo.py \

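The IOI hunk requires `PISTON_ENDPOINTS` to be exported before launch. A sketch of the two forms the diff mentions; the explicit worker URLs below are hypothetical placeholders (the exact list format is defined by the repo's piston setup, see slurm/piston/README.md):

```shell
# Option 1: let the launcher discover piston workers via Slurm.
export PISTON_ENDPOINTS=slurm

# Option 2 (hypothetical placeholder URLs): point at explicit workers.
# export PISTON_ENDPOINTS=http://piston-worker-1:2000/api/v2,http://piston-worker-2:2000/api/v2

echo "$PISTON_ENDPOINTS"   # slurm
```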