diff --git a/torchtitan/experiments/rl/unified/README.md b/torchtitan/experiments/rl/unified/README.md
index fa54a936da..d93e9a4e76 100644
--- a/torchtitan/experiments/rl/unified/README.md
+++ b/torchtitan/experiments/rl/unified/README.md
@@ -66,7 +66,7 @@ python torchtitan/experiments/rl/unified/infer.py --model torchtitan/experiments
 
 5. Run simple rl loop
 ```
-VLLM_BATCH_INVARIANT=1 VLLM_ATTENTION_BACKEND=FLASH_ATTN python3 torchtitan/experiments/rl/unified/simple_rl_multiprocess.py
+python3 torchtitan/experiments/rl/unified/simple_rl_multiprocess.py
 ```
 
 Right now we only support VLLM_COMPAT mode, which could achieve trainer and generator bitwise identical. We are working on support UNIFIED mode, which uses a unified model definition for trainer and generator.
diff --git a/torchtitan/experiments/rl/unified/actors/generator.py b/torchtitan/experiments/rl/unified/actors/generator.py
index d0ee5cf38f..ce6f83e643 100644
--- a/torchtitan/experiments/rl/unified/actors/generator.py
+++ b/torchtitan/experiments/rl/unified/actors/generator.py
@@ -18,6 +18,7 @@
 from torchtitan.experiments.rl.vllm_compat.simple_rl import (
     compute_grpo_advantages,
     compute_grpo_advantages_stable,
+    get_vllm_flash_attention_backend,
     math_reward_function,
     trivial_reward_function,
 )
@@ -197,6 +198,7 @@ def update_weights(self, vllm_compat_state: dict) -> None:
                 seed=42,  # Fixed seed for determinism
                 enforce_eager=True,
                 tensor_parallel_size=self.tp_size,  # Explicitly single GPU
+                attention_config={"backend": get_vllm_flash_attention_backend()},
             )
             logger.info("Created new vLLM engine")
         else:
diff --git a/torchtitan/experiments/rl/unified/simple_rl_multiprocess.py b/torchtitan/experiments/rl/unified/simple_rl_multiprocess.py
index 087e4f1e70..3d93b8adbe 100644
--- a/torchtitan/experiments/rl/unified/simple_rl_multiprocess.py
+++ b/torchtitan/experiments/rl/unified/simple_rl_multiprocess.py
@@ -14,10 +14,13 @@
 The architecture mirrors monarch's grpo_actor.py but adapted for vLLM rollouts
 + TorchTitan training.
 Command to run:
-VLLM_BATCH_INVARIANT=1 VLLM_ATTENTION_BACKEND=FLASH_ATTN python3 torchtitan/experiments/rl/unified/simple_rl_multiprocess.py
+python3 torchtitan/experiments/rl/unified/simple_rl_multiprocess.py
 """
 import asyncio
 import logging
+import os
+
+os.environ["VLLM_BATCH_INVARIANT"] = "1"
 
 import torch
 from monarch.actor import this_host
@@ -27,6 +30,7 @@
 from torchtitan.experiments.rl.unified.models.utils import ModelMode
 from torchtitan.experiments.rl.vllm_compat.simple_rl import (
     download_and_convert_model,
+    get_vllm_flash_attention_backend,
     load_gsm8k_dataset,
 )
 from vllm.model_executor.layers.batch_invariant import (
@@ -63,7 +67,7 @@ async def main():
     trainer_tp_size = 1
     generator_tp_size = 1
 
-    init_batch_invariance()
+    init_batch_invariance(get_vllm_flash_attention_backend())
     batch_invariant = vllm_is_batch_invariant()
 
     mode = ModelMode.VLLM_COMPAT
diff --git a/torchtitan/experiments/rl/vllm_compat/simple_rl.py b/torchtitan/experiments/rl/vllm_compat/simple_rl.py
index 5e1fdd486b..4a7a50455a 100644
--- a/torchtitan/experiments/rl/vllm_compat/simple_rl.py
+++ b/torchtitan/experiments/rl/vllm_compat/simple_rl.py
@@ -36,11 +36,17 @@
 from torchtitan.models.qwen3.model.args import Qwen3ModelArgs
 
 from transformers import AutoConfig, AutoTokenizer
-
 from vllm import LLM, SamplingParams
+
+from vllm.attention.backends.registry import AttentionBackendEnum
 from vllm.model_executor.layers.batch_invariant import init_batch_invariance
 
-init_batch_invariance()
+
+def get_vllm_flash_attention_backend() -> AttentionBackendEnum:
+    return AttentionBackendEnum.FLASH_ATTN
+
+
+init_batch_invariance(get_vllm_flash_attention_backend())
 
 
 class VLLMRolloutEngine:
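Taken together, these hunks move the batch-invariance setup from shell environment variables into the code: `VLLM_BATCH_INVARIANT` is set in-process before any vLLM module is imported, and the attention backend is pinned through the new `get_vllm_flash_attention_backend()` helper instead of `VLLM_ATTENTION_BACKEND=FLASH_ATTN`. Below is a minimal sketch of that initialization order, not part of the patch, assuming a vLLM build that exposes `AttentionBackendEnum` and whose `init_batch_invariance` accepts a backend argument (as the patched modules rely on):

```python
# Sketch only: consolidates the initialization order used across the patched files.
import os

# Set the batch-invariance flag in-process (replaces VLLM_BATCH_INVARIANT=1 on the
# command line), so it is already in the environment when vLLM reads it.
os.environ["VLLM_BATCH_INVARIANT"] = "1"

from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.model_executor.layers.batch_invariant import init_batch_invariance


def get_vllm_flash_attention_backend() -> AttentionBackendEnum:
    # Pin the attention backend used for bitwise-identical trainer/generator runs
    # (replaces VLLM_ATTENTION_BACKEND=FLASH_ATTN on the command line).
    return AttentionBackendEnum.FLASH_ATTN


# Initialize batch-invariant kernels with the pinned backend at import time.
init_batch_invariance(get_vllm_flash_attention_backend())
```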