Description
System Info
verl_v0.6.1 + vllm_v0.11.0 + megatron-core_v0.13.0
The full case is reproduced on H20 × 2 nodes × 8 GPUs by running test_gspo_qwen30b_a3b_ep.sh.
The detailed analysis is reproduced on 3090 × 8 GPUs with a 1-layer pruned model.
Information
- The official example scripts
- My own modified scripts
Tasks
- An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
- My own task or dataset (give details below)
Reproduction
1. Bug overview
verl first initializes the world size via Megatron, then initializes the rollout workers, and finally initializes the vLLM instances.
When the rollout mesh is built, some parallelism parameters are not propagated down to the vLLM initialization, so vLLM's parallel partitioning conflicts with verl's own.
Take 8 GPUs with DP2 TP2 EP4 as an example:
- The mesh built by verl is [[0,1,2,3], [4,5,6,7]]: ranks 0-3 form replica 0 and receive one batch shard, while ranks 4-7 form replica 1 and receive another. Internally each replica is the flattened TP*DP dimension.
- The mesh built by vLLM is [[0,1],[2,3],[4,5],[6,7]]: 4 TP groups in total, each expected to receive a different batch shard.
- As a result, [2,3] repeats the computation of [0,1] and its output is discarded, and [6,7] likewise repeats [4,5], wasting 50% of the compute. Moreover, EP never actually takes effect inside vLLM: the MoE layers use ETP instead of all2all.
With the recommended DP4 TP1 EP4 configuration of the 2-node 16-GPU test_gspo_qwen30b_a3b_ep.sh, there are 16 TP groups; among ranks 0-3 only rank 0's computation is useful while ranks 1-3 repeat it and are discarded, and the same holds for the remaining 3 replicas, wasting 75% of the compute.
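The two conflicting groupings above can be reproduced with a minimal sketch (plain Python, no verl/vLLM imports; variable names are my own):

```python
W, TP, DP, PP = 8, 2, 2, 1  # the 8-GPU DP2 TP2 example above

# verl's view: infer_tp flattens TP*DP; replica count = W / (TP*DP*PP)
infer_tp = TP * DP
replicas = W // (infer_tp * PP)
verl_mesh = [list(range(r * infer_tp, (r + 1) * infer_tp)) for r in range(replicas)]
print(verl_mesh)          # [[0, 1, 2, 3], [4, 5, 6, 7]]

# vLLM's view: only tensor_parallel_size=TP reaches LLM(), so vLLM
# groups every TP adjacent ranks into one TP group
vllm_tp_groups = [list(range(g * TP, (g + 1) * TP)) for g in range(W // TP)]
print(vllm_tp_groups)     # [[0, 1], [2, 3], [4, 5], [6, 7]]

# verl feeds identical data to every rank inside a replica, so within
# replica 0 the vLLM TP groups [0,1] and [2,3] compute the same output;
# only one TP group per replica does useful work
wasted = 1 - replicas / (W // TP)
print(f"wasted compute: {wasted:.0%}")   # wasted compute: 50%
```

Setting W=16, TP=1, DP=4 in the same sketch gives 16 TP groups against 4 replicas, i.e. the 75% waste of the 16-GPU configuration.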
2. Source tracing
2.1 _build_rollout() in verl's megatron_workers
It flattens the input TP and DP dimensions to build the verl-level mesh: dp corresponds to the ExternalDP concept, and infer_tp is the number of ranks needed per replica.
is_collect decides whether a worker's data is gathered back to the single controller; only workers with infer_tp_rank==0 & infer_pp_rank==0 are collected, i.e. rank 0 of each replica.
def _build_rollout(self, trust_remote_code=False):
    from torch.distributed.device_mesh import init_device_mesh

    # 1. parse rollout and huggingface model config
    rollout_config: RolloutConfig = omega_conf_to_dataclass(self.config.rollout)
    model_config: HFModelConfig = omega_conf_to_dataclass(self.config.model, dataclass_type=HFModelConfig)

    # 2. build rollout device mesh
    infer_tp = self.config.rollout.tensor_model_parallel_size * self.config.rollout.data_parallel_size
    infer_pp = self.config.rollout.pipeline_model_parallel_size
    infer_world_size = infer_tp * infer_pp
    dp = self.world_size // infer_world_size
    assert self.world_size % infer_world_size == 0, (
        f"rollout world_size: {self.world_size} is not divisible by infer_world_size: {infer_world_size}"
    )
    print(f"[DEBUG] [_build_rollout] world_size={self.world_size}, "
          f"rollout TP={self.config.rollout.tensor_model_parallel_size}, "
          f"rollout DP={self.config.rollout.data_parallel_size}, "
          f"rollout PP={self.config.rollout.pipeline_model_parallel_size}, "
          f"infer_tp={infer_tp}, infer_pp={infer_pp}, "
          f"infer_world_size={infer_world_size}, "
          f"ExternalDP(dp)={dp}")
    rollout_device_mesh = init_device_mesh(
        get_device_name(), mesh_shape=(dp, infer_tp, infer_pp), mesh_dim_names=["dp", "infer_tp", "infer_pp"]
    )
    is_collect = (
        rollout_device_mesh["infer_tp"].get_local_rank() == 0
        and rollout_device_mesh["infer_pp"].get_local_rank() == 0
    )
    print(f"[DEBUG] [_build_rollout] rollout_device_mesh={rollout_device_mesh}, "
          f"my gen_dp_rank={rollout_device_mesh['dp'].get_local_rank()}, "
          f"my infer_tp_rank={rollout_device_mesh['infer_tp'].get_local_rank()}, "
          f"is_collect={is_collect}")
    self._register_dispatch_collect_info(
        "rollout", dp_rank=rollout_device_mesh["dp"].get_local_rank(), is_collect=is_collect
    )

    # 3. init trainer and rollout random states
    self.torch_random_states = get_torch_device().get_rng_state()
    gen_dp_rank = rollout_device_mesh["dp"].get_local_rank()
    get_torch_device().manual_seed(gen_dp_rank + 1000)  # make sure all tp ranks have the same random states
    self.gen_random_states = get_torch_device().get_rng_state()
    get_torch_device().set_rng_state(self.torch_random_states)

    # 4. build rollout model
    log_gpu_memory_usage(f"Before building {self.config.rollout.name} rollout", logger=logger)
    self.rollout = get_rollout_class(rollout_config.name, rollout_config.mode)(
        config=rollout_config, model_config=model_config, device_mesh=rollout_device_mesh
    )
    log_gpu_memory_usage(f"After building {self.config.rollout.name} rollout", logger=logger)

    # 5. switch to trainer mode
    # NOTE: It's critical that hybrid engine in trainer mode initially to load checkpoint.
    # For sync mode, we directly switch to trainer mode here.
    # For async mode, we can't call run_until_complete here, so we will switch to trainer mode in AgentLoopManager.
    if rollout_config.mode == "sync" and self._is_actor:
        loop = get_event_loop()
        loop.run_until_complete(self.trainer_mode())
2.2 vLLMRollout instance initialization in verl's vllm_rollout_spmd
As shown, the device_mesh that verl built above is never used here; only the TP size from the config is passed directly, and DP and EP are not passed at all.
print(f"[DEBUG] [vLLMRollout.__init__] Creating LLM with: "
      f"tensor_parallel_size={tensor_parallel_size}, "
      f"distributed_executor_backend='external_launcher', "
      f"model_path={model_path}, "
      f"dtype={config.dtype}")
self.inference_engine = LLM(
    model=model_path,
    enable_sleep_mode=config.free_cache_engine,
    tensor_parallel_size=tensor_parallel_size,
    distributed_executor_backend="external_launcher",
    dtype=config.dtype,
    enforce_eager=config.enforce_eager,
    gpu_memory_utilization=config.gpu_memory_utilization,
    disable_custom_all_reduce=True,
    skip_tokenizer_init=False,
    max_model_len=max_model_len,
    max_num_seqs=config.max_num_seqs,
    load_format=load_format,
    disable_log_stats=config.disable_log_stats,
    max_num_batched_tokens=max_num_batched_tokens,
    enable_chunked_prefill=config.enable_chunked_prefill,
    enable_prefix_caching=config.enable_prefix_caching,
    trust_remote_code=trust_remote_code,
    seed=config.get("seed", 0),
    **compilation_config,
    **self.lora_kwargs,
    **engine_kwargs,
)
2.3 The final vLLM parallel configuration, seen inside vllm's parallel_state.py
def _distributed_args(self) -> tuple[str, int, int]:
    # engines are launched in torchrun-compatible launchers
    # so we can use the env:// method.
    # required env vars:
    # - RANK
    # - LOCAL_RANK
    # - MASTER_ADDR
    # - MASTER_PORT
    distributed_init_method = "env://"
    rank = int(os.environ["RANK"])
    local_rank = int(os.environ["LOCAL_RANK"])
    print(f"[DEBUG] [inside VLLM] ExecutorWithExternalLauncher._distributed_args: "
          f"RANK={rank}, LOCAL_RANK={local_rank}, "
          f"MASTER_ADDR={os.environ.get('MASTER_ADDR')}, "
          f"MASTER_PORT={os.environ.get('MASTER_PORT')}")
    return distributed_init_method, rank, local_rank
Reference run script (1-layer pruned model)
Click to expand the full script
set -x
export CUDA_DEVICE_MAX_CONNECTIONS=1 # For megatron communication/computation overlapping
# export TIKTOKEN_RS_CACHE_DIR=/datasets/tiktoken_rs_cache/
export PYTHONUNBUFFERED=1
export VLLM_USE_V1=1
export NCCL_DEBUG=WARN
export RAY_DEDUP_LOGS=0
time=$(date +%m%d_%H%M)
# max_prompt_length=$((1024 * 2))
# max_response_length=$((1024 * 8))
max_prompt_length=1024
max_response_length=1024
actor_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 1))
infer_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 3))
DATA_MODEL_CONFIG=" \
data.train_files=/data/verl_dataset/dapo-math-17k.parquet \
data.val_files=/data/verl_dataset/aime_2024/train.parquet \
data.prompt_key=prompt \
data.return_raw_chat=True \
data.truncation='left' \
data.max_prompt_length=${max_prompt_length} \
data.max_response_length=${max_response_length} \
data.shuffle=True \
data.seed=42 \
actor_rollout_ref.model.path=/data/llm/Qwen3-30B-A3B-1L"
BS_CONFIG=" \
data.train_batch_size=16 \
actor_rollout_ref.rollout.n=2 \
actor_rollout_ref.actor.ppo_epochs=1 \
actor_rollout_ref.actor.ppo_mini_batch_size=16 \
actor_rollout_ref.actor.use_dynamic_bsz=True \
actor_rollout_ref.ref.log_prob_use_dynamic_bsz=True \
actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=True \
actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \
actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \
actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len}"
ALGO_CONFIG=" \
actor_rollout_ref.actor.policy_loss.loss_mode=gspo \
algorithm.adv_estimator=grpo \
algorithm.use_kl_in_reward=False \
algorithm.kl_ctrl.kl_coef=0.0 \
actor_rollout_ref.actor.use_kl_loss=False \
actor_rollout_ref.actor.kl_loss_coef=0.0 \
actor_rollout_ref.actor.clip_ratio_low=3e-4 \
actor_rollout_ref.actor.clip_ratio_high=4e-4 \
actor_rollout_ref.actor.clip_ratio_c=10.0 \
actor_rollout_ref.actor.entropy_coeff=0 \
actor_rollout_ref.actor.loss_agg_mode=token-mean"
ROLLOUT_CONFIG=" \
actor_rollout_ref.rollout.gpu_memory_utilization=0.40 \
actor_rollout_ref.rollout.enable_chunked_prefill=True \
actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) \
actor_rollout_ref.rollout.temperature=1.0 \
actor_rollout_ref.rollout.top_p=1.0 \
actor_rollout_ref.rollout.top_k=-1 \
actor_rollout_ref.rollout.val_kwargs.temperature=1.0 \
actor_rollout_ref.rollout.val_kwargs.top_p=1.0 \
actor_rollout_ref.rollout.val_kwargs.top_k=-1 \
actor_rollout_ref.rollout.val_kwargs.do_sample=True \
actor_rollout_ref.rollout.val_kwargs.n=1 \
actor_rollout_ref.rollout.name=vllm \
actor_rollout_ref.rollout.mode=sync \
actor_rollout_ref.rollout.calculate_log_probs=True \
actor_rollout_ref.rollout.tensor_model_parallel_size=2 \
actor_rollout_ref.rollout.data_parallel_size=2 \
actor_rollout_ref.rollout.expert_parallel_size=4 \
actor_rollout_ref.rollout.enable_expert_parallel=True \
actor_rollout_ref.rollout.skip_rollout=False \
actor_rollout_ref.rollout.skip_dump_dir=/verl/workspace/rollout_cache_dump"
ACTOR_CONFIG=""
REF_CONFIG=""
REWARD_CONFIG=" \
reward_model.reward_manager=dapo \
+reward_model.reward_kwargs.overlong_buffer_cfg.enable=True \
+reward_model.reward_kwargs.overlong_buffer_cfg.len=1024 \
+reward_model.reward_kwargs.overlong_buffer_cfg.penalty_factor=1.0 \
+reward_model.reward_kwargs.overlong_buffer_cfg.log=False \
+reward_model.reward_kwargs.max_resp_len=${max_response_length}"
CRITIC_CONFIG=""
# Megatron parallelism config (replaces FSDP_CONFIG)
MEGATRON_CONFIG=" \
actor_rollout_ref.actor.megatron.param_offload=True \
actor_rollout_ref.actor.megatron.optimizer_offload=True \
actor_rollout_ref.actor.megatron.grad_offload=True \
actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=1 \
actor_rollout_ref.actor.megatron.tensor_model_parallel_size=2 \
actor_rollout_ref.actor.megatron.expert_model_parallel_size=4 \
actor_rollout_ref.actor.megatron.expert_tensor_parallel_size=1 \
actor_rollout_ref.actor.megatron.use_mbridge=True \
+actor_rollout_ref.actor.megatron.override_transformer_config.apply_rope_fusion=True \
+actor_rollout_ref.actor.megatron.override_transformer_config.moe_router_dtype=fp32 \
+actor_rollout_ref.actor.megatron.override_transformer_config.recompute_method=uniform \
+actor_rollout_ref.actor.megatron.override_transformer_config.recompute_granularity=full \
+actor_rollout_ref.actor.megatron.override_transformer_config.recompute_num_layers=1 \
+actor_rollout_ref.actor.megatron.override_transformer_config.gradient_accumulation_fusion=True \
+actor_rollout_ref.actor.megatron.override_transformer_config.moe_permute_fusion=True"
OPTIMIZER_CONFIG=" \
actor_rollout_ref.actor.optim.lr=1e-6 \
actor_rollout_ref.actor.optim.lr_warmup_steps=0 \
actor_rollout_ref.actor.optim.weight_decay=0.1 \
actor_rollout_ref.actor.optim.clip_grad=1.0"
TRAIN_CONFIG=" \
trainer.nnodes=1 \
trainer.n_gpus_per_node=8 \
trainer.project_name=DAPO \
trainer.experiment_name=GSPO-Qwen3-30B-A3B-Base-MATH \
trainer.logger=['console','tensorboard'] \
trainer.val_before_train=False \
trainer.save_freq=-1 \
trainer.test_freq=-1 \
trainer.total_epochs=1 \
trainer.total_training_steps=5 \
trainer.resume_mode=auto"
# ==================== Run script ====================
SCRIPT=" \
python3 -m verl.trainer.main_ppo \
--config-path=config \
--config-name=ppo_megatron_trainer.yaml \
${DATA_MODEL_CONFIG} \
${BS_CONFIG} \
${ALGO_CONFIG} \
${ROLLOUT_CONFIG} \
${ACTOR_CONFIG} \
${REF_CONFIG} \
${REWARD_CONFIG} \
${CRITIC_CONFIG} \
${MEGATRON_CONFIG} \
${OPTIMIZER_CONFIG} \
${TRAIN_CONFIG} \
2>&1 | tee logs/case4/run_TX_2x8_GSPO_vllm_Megatron_10k_v0.6.1_${time}.log"
echo ${SCRIPT}
eval ${SCRIPT}
Expected behavior
Current configuration: 8 GPUs, TP=2, DP=2, EP=4

| Concept | Formula | This example |
|---|---|---|
| **Passed-in config** | | |
| TP | rollout.tensor_model_parallel_size | 2 |
| DP | rollout.data_parallel_size | 2 |
| PP | rollout.pipeline_model_parallel_size | 1 |
| EP | rollout.expert_parallel_size | 4 |
| **verl layer (_build_rollout)** | | |
| infer_tp | TP × DP | 2 × 2 = 4 |
| infer_pp | PP | 1 |
| infer_world_size | infer_tp × infer_pp = TP × DP × PP | 4 × 1 = 4 |
| verl ExternalDP | W / infer_world_size = W / (TP × DP × PP) | 8 / 4 = 2 |
| device_mesh shape | (ExternalDP, infer_tp, infer_pp) | (2, 4, 1) |
| GPUs per replica | infer_world_size = TP × DP × PP | 4 |
| verl replica grouping | count equals ExternalDP | [[0,1,2,3], [4,5,6,7]] |
| **vLLM layer (what LLM() actually receives)** | | |
| vllm TP | TP (the value configured in the config) | 2 |
| vllm DP | 1 (not passed, default) | 1 |
| vllm PP | 1 (not passed, default) | 1 |
| vllm EP | False (not passed, default) | False |
| vllm config world_size | vllm_TP × vllm_DP × vllm_PP = TP | 2 |
| torch world_size | W (reuses the global group from the earlier Megatron init) | 8 |
| vllm ExternalDP | torch_world_size / vllm_TP = W / TP | 8 / 2 = 4 |
| vllm TP groups | every TP adjacent ranks form one group | [[0,1],[2,3],[4,5],[6,7]] |
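The derivation in the table can be checked with a small parametric sketch (plain Python, no verl/vLLM imports; the function name is my own), covering both the 8-GPU repro and the recommended 16-GPU configuration:

```python
def derive_parallelism(W, TP, DP, PP=1):
    """Derive verl's and vLLM's (conflicting) views of the same rollout config."""
    # verl layer (_build_rollout): infer_tp flattens TP*DP
    infer_tp = TP * DP
    infer_world_size = infer_tp * PP
    verl_external_dp = W // infer_world_size
    # vLLM layer: only TP is passed to LLM(), so vLLM derives its own ExternalDP
    vllm_external_dp = W // TP
    # within each verl replica only one vLLM TP group's output is kept
    wasted = 1 - verl_external_dp / vllm_external_dp
    return dict(
        infer_tp=infer_tp,
        mesh_shape=(verl_external_dp, infer_tp, PP),
        vllm_external_dp=vllm_external_dp,
        wasted=wasted,
    )

# 8 GPUs, TP=2, DP=2 (this issue's repro): mesh (2, 4, 1), 4 vLLM TP groups, 50% wasted
print(derive_parallelism(W=8, TP=2, DP=2))
# 16 GPUs, TP=1, DP=4 (recommended test_gspo_qwen30b_a3b_ep.sh config): 75% wasted
print(derive_parallelism(W=16, TP=1, DP=4))
```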
Verification logs
1. Trace of _build_rollout()
The config is passed in correctly: infer_tp=4, infer_pp=1, device_mesh=(2, 4, 1). Ranks 0-3 receive the same data, ranks 4-7 receive the same data, and only rank 0's and rank 4's data are collected.
(WorkerDict pid=186803) [DEBUG] [_build_rollout] world_size=8, rollout TP=2, rollout DP=2, rollout PP=1, infer_tp=4, infer_pp=1, infer_world_size=4, ExternalDP(dp)=2
(WorkerDict pid=186803) [DEBUG] [_build_rollout] rollout_device_mesh=DeviceMesh((dp=2, infer_tp=4, infer_pp=1), device: 'cuda', stride: (4, 1, 1)), my gen_dp_rank=0, my infer_tp_rank=2, is_collect=False
(WorkerDict pid=186805) [DEBUG] [_build_rollout] world_size=8, rollout TP=2, rollout DP=2, rollout PP=1, infer_tp=4, infer_pp=1, infer_world_size=4, ExternalDP(dp)=2
(WorkerDict pid=186805) [DEBUG] [_build_rollout] rollout_device_mesh=DeviceMesh((dp=2, infer_tp=4, infer_pp=1), device: 'cuda', stride: (4, 1, 1)), my gen_dp_rank=1, my infer_tp_rank=0, is_collect=True
(WorkerDict pid=186810) [DEBUG] [_build_rollout] world_size=8, rollout TP=2, rollout DP=2, rollout PP=1, infer_tp=4, infer_pp=1, infer_world_size=4, ExternalDP(dp)=2
(WorkerDict pid=186810) [DEBUG] [_build_rollout] rollout_device_mesh=DeviceMesh((dp=2, infer_tp=4, infer_pp=1), device: 'cuda', stride: (4, 1, 1)), my gen_dp_rank=1, my infer_tp_rank=2, is_collect=False
(WorkerDict pid=186804) [DEBUG] [_build_rollout] world_size=8, rollout TP=2, rollout DP=2, rollout PP=1, infer_tp=4, infer_pp=1, infer_world_size=4, ExternalDP(dp)=2
(WorkerDict pid=186804) [DEBUG] [_build_rollout] rollout_device_mesh=DeviceMesh((dp=2, infer_tp=4, infer_pp=1), device: 'cuda', stride: (4, 1, 1)), my gen_dp_rank=0, my infer_tp_rank=3, is_collect=False
(WorkerDict pid=186801) [DEBUG] [_build_rollout] world_size=8, rollout TP=2, rollout DP=2, rollout PP=1, infer_tp=4, infer_pp=1, infer_world_size=4, ExternalDP(dp)=2
(WorkerDict pid=186801) [DEBUG] [_build_rollout] rollout_device_mesh=DeviceMesh((dp=2, infer_tp=4, infer_pp=1), device: 'cuda', stride: (4, 1, 1)), my gen_dp_rank=0, my infer_tp_rank=0, is_collect=True
(WorkerDict pid=186806) [DEBUG] [_build_rollout] world_size=8, rollout TP=2, rollout DP=2, rollout PP=1, infer_tp=4, infer_pp=1, infer_world_size=4, ExternalDP(dp)=2
2. Trace of the vLLMRollout instance initialization
Only tensor_parallel_size is passed:
(WorkerDict pid=186803) [DEBUG] vLLMRollout.__init__() config:
(WorkerDict pid=186803) RolloutConfig(_target_='', name='vllm', mode='sync', skip_tokenizer_init=True, temperature=1.0, top_k=-1, top_p=1.0, do_sample=True, n=2, over_sample_rate=0, prompt_length=1024, response_length=1024, dtype='bfloat16', gpu_memory_utilization=0.4, ignore_eos=False, enforce_eager=False, cudagraph_capture_sizes=None, free_cache_engine=True, data_parallel_size=2, expert_parallel_size=4, tensor_model_parallel_size=2, pipeline_model_parallel_size=1, max_num_batched_tokens=2048, val_kwargs=SamplingConfig(_target_='', temperature=1.0, top_k=-1, top_p=1.0, do_sample=True, n=1), max_model_len=None, max_num_seqs=1024, log_prob_micro_batch_size=None, log_prob_micro_batch_size_per_gpu=None, log_prob_use_dynamic_bsz=True, log_prob_max_token_len_per_gpu=6144, disable_log_stats=True, multi_stage_wake_up=False, engine_kwargs={'vllm': {}, 'sglang': {}}, calculate_log_probs=True, agent=AgentLoopConfig(_target_='', num_workers=8, default_agent_loop='single_turn_agent', agent_loop_config_path=None, custom_async_server=CustomAsyncServerConfig(_target_='', path=None, name=None)), trace=TraceConfig(_target_='', backend=None, token2text=False), multi_turn=MultiTurnConfig(_target_='', enable=False, max_assistant_turns=None, tool_config_path=None, max_user_turns=None, max_parallel_calls=1, max_tool_response_length=256, tool_response_truncate_side='middle', interaction_config_path=None, use_inference_chat_template=False, tokenization_sanity_check_mode='strict', format='hermes', num_repeat_rollouts=None), server=ServerConfig(_target_='', timeout=60.0, max_attempts=3, retry_delay=2.0, max_connections=1000, max_start_wait_time=300.0), prometheus=PrometheusConfig(_target_='', enable=False, port=9090, file='/tmp/ray/session_latest/metrics/prometheus/prometheus.yml', served_model_name='/data/llm/Qwen3-30B-A3B-1L'), update_weights_bucket_megabytes=512, skip_rollout=False, skip_dump_dir='/verl/workspace/rollout_cache_dump/Case4_test', profiler=ProfilerConfig(_target_='', tool=None, enable=False, all_ranks=False, ranks=[], save_path='outputs/profile', tool_config={'nsys': NsightToolConfig(_target_='', discrete=False), 'npu': NPUToolConfig(_target_='', discrete=False, contents=[], level='level1', analysis=True), 'torch': TorchProfilerToolConfig(_target_='', step_start=0, step_end=None), 'torch_memory': TorchMemoryToolConfig(_target_='', trace_alloc_max_entries=100000, stack_depth=32)}, global_tool_config=None), enable_chunked_prefill=True, enable_prefix_caching=True, load_format='dummy', layered_summon=False, layer_name_map={'qkv_layer_name': 'qkv', 'gate_proj_layer_name': 'gate_up'}, sglang_engine_mode='local', limit_images=None)
(WorkerDict pid=186803) [DEBUG] [vLLMRollout.__init__] Creating LLM with: tensor_parallel_size=2, distributed_executor_backend='external_launcher', model_path=/data/llm/Qwen3-30B-A3B-1L, dtype=bfloat16
3. Trace of vllm's internal initialization
Each rank obtains its actual rank from the environment variables, while local_rank is 0 everywhere.
Indeed only the TP size was passed, and the remaining ranks replicate the model in the form of ExternalDP.
(WorkerDict pid=186803) [DEBUG] [vLLMRollout.__init__] Creating LLM with: tensor_parallel_size=2, distributed_executor_backend='external_launcher', model_path=/data/llm/Qwen3-30B-A3B-1L, dtype=bfloat16
(WorkerDict pid=186803) `torch_dtype` is deprecated! Use `dtype` instead!
(WorkerDict pid=186803) [DEBUG] [inside VLLM] ExecutorWithExternalLauncher._distributed_args: RANK=2, LOCAL_RANK=0, MASTER_ADDR=10.9.113.99, MASTER_PORT=33509
(WorkerDict pid=186803) [DEBUG] [inside VLLM] initialize_model_parallel: all_ranks shape=torch.Size([4, 1, 1, 2]), TP groups=[[0, 1], [2, 3], [4, 5], [6, 7]], data_parallel_size_from_config=1
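The all_ranks shape and TP grouping in the log above can be reproduced with a small sketch (plain Python; the dimension order (ExternalDP, DP, PP, TP) is an assumption that matches the logged shape [4, 1, 1, 2]):

```python
# What vLLM received from verl: only TP=2; DP, PP, EP were never passed
W, vllm_dp, vllm_pp, vllm_tp = 8, 1, 1, 2

# vLLM derives its own ExternalDP from the torch world size and its config
external_dp = W // (vllm_dp * vllm_pp * vllm_tp)
all_ranks_shape = (external_dp, vllm_dp, vllm_pp, vllm_tp)
print(all_ranks_shape)  # (4, 1, 1, 2)

# TP groups are slices along the last dimension: every vllm_tp adjacent ranks
tp_groups = [list(range(i, i + vllm_tp)) for i in range(0, W, vllm_tp)]
print(tp_groups)        # [[0, 1], [2, 3], [4, 5], [6, 7]]
```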
4. Trace of generate input/output on the workers
@GPUMemoryLogger(role="vllm rollout spmd", logger=logger)
@torch.no_grad()
def generate_sequences(self, prompts: DataProto, **kwargs) -> DataProto:
    ......
    ......
    import torch.distributed as dist
    rank = dist.get_rank()
    # print the first 10 tokens of each rank's first response
    non_pad_mask = (idx[0] != 151643)  # 151643 is the pad token
    actual_prompt = idx[0][non_pad_mask]
    print(f"[DEBUG] [VERIFY] rank={rank}, "
          f"batch_size={batch_size}, "
          f"actual_prompt[70:80]={actual_prompt[70:80].tolist()}, "
          f"actual_prompt_len={len(actual_prompt)}, "
          f"response[0][:10]={response[0][:10].tolist()}, "
          f"response[1][:10]={response[1][:10].tolist()}")
    batch = TensorDict(
        {
            "prompts": idx,
            "responses": response,
            "input_ids": seq,  # here input_ids become the whole sentences
            "attention_mask": attention_mask,
            "position_ids": position_ids,
        },
        batch_size=batch_size,
    )
As shown, ranks 0-3 received identical inputs and produced identical outputs, and ranks 4-7 likewise. This matches the expected partitioning of verl's device_mesh, but vLLM's 4 TP groups expect 4 different input shards, which is where the conflict arises.
Among the 4 groups [[0,1],[2,3],[4,5],[6,7]], [2,3] repeats the computation of [0,1] and is discarded, and [6,7] repeats [4,5] and is discarded, wasting 50% of the compute.
Training Progress: 0%| | 0/5 [00:00<?, ?it/s]
(WorkerDict pid=186806) [DEBUG] [VERIFY] rank=5, batch_size=16, actual_prompt[70:80]=[2182, 697, 4226, 389, 1181, 1828, 1555, 1283, 330, 16141], actual_prompt_len=86, response[0][:10]=[140474, 126457, 90201, 12268, 14830, 14830, 7180, 96610, 5147, 60093], response[1][:10]=[114323, 72963, 19820, 55988, 65459, 57950, 114323, 4513, 88, 112981]
(WorkerDict pid=186802) [DEBUG] [VERIFY] rank=1, batch_size=16, actual_prompt[70:80]=[11, 264, 293, 59, 31716, 58657, 12857, 400, 64, 3], actual_prompt_len=142, response[0][:10]=[140474, 126457, 90201, 72145, 59542, 2326, 82541, 25223, 39063, 59325], response[1][:10]=[141650, 80970, 5510, 137906, 304, 42297, 15255, 29970, 105869, 4642]
(WorkerDict pid=186803) [DEBUG] [VERIFY] rank=2, batch_size=16, actual_prompt[70:80]=[11, 264, 293, 59, 31716, 58657, 12857, 400, 64, 3], actual_prompt_len=142, response[0][:10]=[140474, 126457, 90201, 72145, 59542, 2326, 82541, 25223, 39063, 59325], response[1][:10]=[141650, 80970, 5510, 137906, 304, 42297, 15255, 29970, 105869, 4642]
(WorkerDict pid=186805) [DEBUG] [VERIFY] rank=4, batch_size=16, actual_prompt[70:80]=[2182, 697, 4226, 389, 1181, 1828, 1555, 1283, 330, 16141], actual_prompt_len=86, response[0][:10]=[140474, 126457, 90201, 12268, 14830, 14830, 7180, 96610, 5147, 60093], response[1][:10]=[114323, 72963, 19820, 55988, 65459, 57950, 114323, 4513, 88, 112981]
(WorkerDict pid=186801) [DEBUG] [VERIFY] rank=0, batch_size=16, actual_prompt[70:80]=[11, 264, 293, 59, 31716, 58657, 12857, 400, 64, 3], actual_prompt_len=142, response[0][:10]=[140474, 126457, 90201, 72145, 59542, 2326, 82541, 25223, 39063, 59325], response[1][:10]=[141650, 80970, 5510, 137906, 304, 42297, 15255, 29970, 105869, 4642]
(WorkerDict pid=186810) [DEBUG] [VERIFY] rank=6, batch_size=16, actual_prompt[70:80]=[2182, 697, 4226, 389, 1181, 1828, 1555, 1283, 330, 16141], actual_prompt_len=86, response[0][:10]=[140474, 126457, 90201, 12268, 14830, 14830, 7180, 96610, 5147, 60093], response[1][:10]=[114323, 72963, 19820, 55988, 65459, 57950, 114323, 4513, 88, 112981]
(WorkerDict pid=186815) [DEBUG] [VERIFY] rank=7, batch_size=16, actual_prompt[70:80]=[2182, 697, 4226, 389, 1181, 1828, 1555, 1283, 330, 16141], actual_prompt_len=86, response[0][:10]=[140474, 126457, 90201, 12268, 14830, 14830, 7180, 96610, 5147, 60093], response[1][:10]=[114323, 72963, 19820, 55988, 65459, 57950, 114323, 4513, 88, 112981]
(WorkerDict pid=186804) [DEBUG] [VERIFY] rank=3, batch_size=16, actual_prompt[70:80]=[11, 264, 293, 59, 31716, 58657, 12857, 400, 64, 3], actual_prompt_len=142, response[0][:10]=[140474, 126457, 90201, 72145, 59542, 2326, 82541, 25223, 39063, 59325], response[1][:10]=[141650, 80970, 5510, 137906, 304, 42297, 15255, 29970, 105869, 4642]