torchtitan/experiments/rl/unified/README.md (1 addition, 1 deletion)

@@ -66,7 +66,7 @@

 5. Run simple rl loop
 ```
-VLLM_BATCH_INVARIANT=1 VLLM_ATTENTION_BACKEND=FLASH_ATTN python3 torchtitan/experiments/rl/unified/simple_rl_multiprocess.py
+python3 torchtitan/experiments/rl/unified/simple_rl_multiprocess.py
 ```
 Right now we only support VLLM_COMPAT mode, which can achieve bitwise-identical results between trainer and generator. We are working on supporting
 UNIFIED mode, which uses a unified model definition for trainer and generator.
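The "bitwise identical" claim for VLLM_COMPAT mode is directly checkable: batch-invariant kernels remove the usual need for tolerances, so trainer and generator logprobs can be compared for exact equality. A minimal sketch of such a check (the function and tensor names are hypothetical, not part of this PR):

```python
import torch


def assert_bitwise_identical(
    trainer_logprobs: torch.Tensor, generator_logprobs: torch.Tensor
) -> None:
    # torch.equal demands exact elementwise equality -- no atol/rtol --
    # which is what VLLM_COMPAT's batch-invariant kernels are meant to deliver.
    if not torch.equal(trainer_logprobs, generator_logprobs):
        max_diff = (trainer_logprobs - generator_logprobs).abs().max().item()
        raise AssertionError(f"trainer/generator diverged: max abs diff={max_diff}")
```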
torchtitan/experiments/rl/unified/actors/generator.py (2 additions)

@@ -18,6 +18,7 @@
 from torchtitan.experiments.rl.vllm_compat.simple_rl import (
     compute_grpo_advantages,
     compute_grpo_advantages_stable,
+    get_vllm_flash_attention_backend,
     math_reward_function,
     trivial_reward_function,
 )

@@ -197,6 +198,7 @@ def update_weights(self, vllm_compat_state: dict) -> None:
                 seed=42,  # Fixed seed for determinism
                 enforce_eager=True,
                 tensor_parallel_size=self.tp_size,  # Explicitly single GPU
+                attention_config={"backend": get_vllm_flash_attention_backend()},
             )
             logger.info("Created new vLLM engine")
         else:
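Read together with the surrounding call site, the engine construction after this change looks roughly like the sketch below; the model path is hypothetical and the remaining arguments of the real call are omitted. `attention_config` is the keyword this diff adds, pinning the backend in code where the old workflow relied on `VLLM_ATTENTION_BACKEND=FLASH_ATTN`:

```python
from torchtitan.experiments.rl.vllm_compat.simple_rl import (
    get_vllm_flash_attention_backend,
)
from vllm import LLM

llm = LLM(
    model="Qwen/Qwen3-0.6B",  # hypothetical checkpoint
    seed=42,  # fixed seed for determinism
    enforce_eager=True,
    tensor_parallel_size=1,
    # Select the attention backend programmatically instead of via env var.
    attention_config={"backend": get_vllm_flash_attention_backend()},
)
```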
torchtitan/experiments/rl/unified/simple_rl_multiprocess.py (6 additions, 2 deletions)

@@ -14,10 +14,13 @@
 The architecture mirrors monarch's grpo_actor.py but adapted for vLLM rollouts + TorchTitan training.
 
 Command to run:
-VLLM_BATCH_INVARIANT=1 VLLM_ATTENTION_BACKEND=FLASH_ATTN python3 torchtitan/experiments/rl/unified/simple_rl_multiprocess.py
+python3 torchtitan/experiments/rl/unified/simple_rl_multiprocess.py
 """
 import asyncio
 import logging
+import os
+
+os.environ["VLLM_BATCH_INVARIANT"] = "1"
 
 import torch
 from monarch.actor import this_host

@@ -27,6 +30,7 @@
 from torchtitan.experiments.rl.unified.models.utils import ModelMode
 from torchtitan.experiments.rl.vllm_compat.simple_rl import (
     download_and_convert_model,
+    get_vllm_flash_attention_backend,
     load_gsm8k_dataset,
 )
 from vllm.model_executor.layers.batch_invariant import (

@@ -63,7 +67,7 @@ async def main():
     trainer_tp_size = 1
     generator_tp_size = 1
 
-    init_batch_invariance()
+    init_batch_invariance(get_vllm_flash_attention_backend())
     batch_invariant = vllm_is_batch_invariant()
     mode = ModelMode.VLLM_COMPAT
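The ordering here matters: `VLLM_BATCH_INVARIANT` must be set before anything imports vLLM, which is why the diff writes it into `os.environ` at the top of the module, ahead of the remaining imports. A condensed sketch of the resulting startup sequence (assuming, as the imports above suggest, that `vllm_is_batch_invariant` lives in the same `batch_invariant` module):

```python
import os

# Must happen before any vllm import so the batch-invariant kernel
# selection sees the flag at module-import time.
os.environ["VLLM_BATCH_INVARIANT"] = "1"

from vllm.model_executor.layers.batch_invariant import (  # noqa: E402
    init_batch_invariance,
    vllm_is_batch_invariant,
)

from torchtitan.experiments.rl.vllm_compat.simple_rl import (  # noqa: E402
    get_vllm_flash_attention_backend,
)

# The backend is now passed explicitly instead of via VLLM_ATTENTION_BACKEND.
init_batch_invariance(get_vllm_flash_attention_backend())
assert vllm_is_batch_invariant()
```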
torchtitan/experiments/rl/vllm_compat/simple_rl.py (8 additions, 2 deletions)

@@ -36,11 +36,17 @@

 from torchtitan.models.qwen3.model.args import Qwen3ModelArgs
 from transformers import AutoConfig, AutoTokenizer
 
 from vllm import LLM, SamplingParams
 
+from vllm.attention.backends.registry import AttentionBackendEnum
 from vllm.model_executor.layers.batch_invariant import init_batch_invariance
 
-init_batch_invariance()
+
+def get_vllm_flash_attention_backend() -> AttentionBackendEnum:
+    return AttentionBackendEnum.FLASH_ATTN
+
+
+init_batch_invariance(get_vllm_flash_attention_backend())
 
 
 class VLLMRolloutEngine: