Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 54 additions & 0 deletions trl/generation/vllm_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -511,6 +511,60 @@ def update_named_param(self, name: str, weights: torch.Tensor):
self.communicator.broadcast(weights, src=self.rank)
self.communicator.group.barrier()

def batch_update_named_params(self, params: list[tuple[str, torch.Tensor]], chunk_size: int | None = None):
    """
    Updates multiple named parameters in a single batch, reducing HTTP round-trips.

    For each chunk, parameter metadata (name, dtype, shape) is sent in a single HTTP POST,
    then each tensor is broadcast via NCCL in the same order — the server pre-allocates its
    receive buffers from the metadata and consumes the broadcasts in exactly that order.

    Args:
        params: List of `(name, weights_tensor)` tuples.
        chunk_size: Maximum total number of tensor elements per HTTP call. `None` sends all
            parameters in one call. Use smaller values for large models to avoid HTTP
            request size limits.
    """
    # Nothing to sync: avoid a pointless HTTP round-trip and collective barrier.
    if not params:
        return

    def _pack_chunks(items):
        # Greedy packing: start a new chunk once adding the next tensor would push the
        # running element count past `chunk_size`. A single tensor larger than
        # `chunk_size` still gets its own (oversized) chunk rather than being split.
        chunks, current, elements = [], [], 0
        for name, weights in items:
            n = weights.numel()
            if current and elements + n > chunk_size:
                chunks.append(current)
                current, elements = [], 0
            current.append((name, weights))
            elements += n
        if current:
            chunks.append(current)
        return chunks

    chunks = [params] if chunk_size is None else _pack_chunks(params)

    for chunk in chunks:
        # Metadata first; the server uses it to allocate buffers for the broadcasts below.
        param_metadata = [
            {"name": name, "dtype": str(weights.dtype), "shape": list(weights.shape)}
            for name, weights in chunk
        ]
        url = f"{self.base_url}/batch_update_named_params/"
        response = self.session.post(url, json={"params": param_metadata})
        if response.status_code != 200:
            raise Exception(f"Request failed: {response.status_code}, {response.text}")

        # Broadcast each tensor via NCCL, in metadata order.
        for _name, weights in chunk:
            if is_torch_xpu_available():
                self.communicator.broadcast(weights, root=self.rank)
            else:
                self.communicator.broadcast(weights, src=self.rank)

        # One barrier per chunk, mirroring the single barrier the server-side
        # `batch_update_named_params` performs after its broadcasts.
        if is_torch_xpu_available():
            self.communicator.barrier()
        else:
            self.communicator.group.barrier()

def update_model_params(self, model: nn.Module):
"""
Updates all parameters of the given model by calling `update_named_param` for each parameter in the model.
Expand Down
24 changes: 23 additions & 1 deletion trl/generation/vllm_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,10 @@ class VLLMGeneration:
group_port (`int`, *optional*, defaults to `51216`):
Port number for the weight update group. This is used to communicate with the vLLM server. Unless the port
is occupied, there is no need to change it.
weight_sync_chunk_size (`int` or `None`, *optional*, defaults to `None`):
Maximum total tensor elements per HTTP request during batched weight sync to the vLLM server. `None`
(default) sends all parameters in a single request. Set to a smaller value (e.g. `100_000_000`) for
large models to avoid exceeding HTTP request size limits.

> Parameters for "colocate" vLLM mode:

Expand Down Expand Up @@ -245,6 +249,7 @@ def __init__(
server_port: int = 8000,
server_timeout: float = 240.0,
group_port: int = 51216,
weight_sync_chunk_size: int | None = None,
# Colocate mode configuration
tensor_parallel_size: int = 1,
gpu_memory_utilization: float = 0.9,
Expand Down Expand Up @@ -280,8 +285,9 @@ def __init__(
self.server_base_url = server_base_url
self.server_host = server_host
self.server_port = server_port
self.group_port = group_port
self.server_timeout = server_timeout
self.group_port = group_port
self.weight_sync_chunk_size = weight_sync_chunk_size

# Colocate mode configuration
self.tensor_parallel_size = tensor_parallel_size
Expand Down Expand Up @@ -500,8 +506,20 @@ def sync_weights(self):
self._sync_fsdp1_params_to_vllm(model) # use memory-efficient post-order traversal for FSDP
elif fsdp_version == 2:
self._sync_fsdp2_params_to_vllm(model)
elif not zero_stage_3 and self.mode == "server" and accelerator.is_main_process:
params = []
for name, param in model.named_parameters():
name = name.removeprefix("base_model.model.").replace(".base_layer", "")
if model.prefix in name:
continue
if "original_module" in name:
continue
name = self._fix_param_name_to_vllm(name, extra_prefixes=["modules_to_save.default."])
params.append((name, param.data))
self.vllm_client.batch_update_named_params(params, chunk_size=self.weight_sync_chunk_size)
else:
# DeepSpeed ZeRO-3 with PEFT
# ZeRO-3 gathers per-param so we can't easily batch sync them to vLLM
for name, param in model.named_parameters():
# When using PEFT, we need to recover the original parameter name
name = name.removeprefix("base_model.model.").replace(".base_layer", "")
Expand Down Expand Up @@ -530,6 +548,10 @@ def sync_weights(self):
self._sync_fsdp1_params_to_vllm(model) # use memory-efficient post-order traversal for FSDP
elif fsdp_version == 2:
self._sync_fsdp2_params_to_vllm(model)
elif not zero_stage_3 and self.mode == "server" and accelerator.is_main_process:
params = [(self._fix_param_name_to_vllm(name), param.data)
for name, param in model.named_parameters()]
self.vllm_client.batch_update_named_params(params, chunk_size=self.weight_sync_chunk_size)
else:
for name, param in model.named_parameters():
name = self._fix_param_name_to_vllm(name)
Expand Down
57 changes: 53 additions & 4 deletions trl/scripts/vllm_serve.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

import argparse
import asyncio
import base64
import logging
import os
Expand Down Expand Up @@ -185,6 +186,38 @@ def update_named_param(self, name: str, dtype: str, shape: Sequence[int]) -> Non
# Load the received weights into the model.
self.model_runner.model.load_weights(weights=[(name, weight)])

def batch_update_named_params(self, params: list[tuple[str, str, tuple[int, ...]]]) -> None:
    """
    Receives and updates multiple named parameters in a single batch.

    Avoids per-parameter HTTP round-trips: the client broadcasts each parameter's
    tensor via NCCL in the same order as the `params` list, so buffers are allocated
    and received here in that exact order.

    Args:
        params: List of `(name, dtype, shape)` tuples, one per parameter.
    """
    if self.communicator is None:
        raise RuntimeError("Communicator not initialized. Call `init_communicator` first.")

    # The XPU/CUDA split only affects the broadcast/barrier keyword; resolve it once.
    on_xpu = is_torch_xpu_available()

    received: list[tuple[str, torch.Tensor]] = []
    for name, dtype_str, shape in params:
        # dtype arrives as e.g. "torch.float32"; resolve the attribute on torch.
        torch_dtype = getattr(torch, dtype_str.split(".")[-1])
        buffer = torch.empty(shape, dtype=torch_dtype, device=self.device)

        if on_xpu:
            self.communicator.broadcast(buffer, root=self.client_rank)
        else:
            self.communicator.broadcast(buffer, src=self.client_rank)

        received.append((name, buffer))

    # Single barrier after all broadcasts, matching the client's per-chunk barrier.
    if on_xpu:
        self.communicator.barrier()
    else:
        self.communicator.group.barrier()

    # Load all weights at once.
    self.model_runner.model.load_weights(weights=received)

def close_communicator(self) -> None:
"""
Closes the communicator when weight synchronization is no longer needed.
Expand Down Expand Up @@ -852,8 +885,9 @@ async def init_communicator(request: InitCommunicatorRequest):
"method": "init_communicator",
"args": (request.host, request.port, world_size, request.client_device_uuid),
}
for connection in connections:
connection.send({"type": "fire_and_forget", "method": "collective_rpc", "kwargs": kwargs})
msg = {"type": "fire_and_forget", "method": "collective_rpc", "kwargs": kwargs}
loop = asyncio.get_running_loop()
await asyncio.gather(*(loop.run_in_executor(None, conn.send, msg) for conn in connections))

return {"message": "Request received, initializing communicator"}

Expand All @@ -880,11 +914,26 @@ async def update_named_param(request: UpdateWeightsRequest):
# So with collective_rpc we need to call it this way:
# llm.collective_rpc("update_named_param", args=("name", "torch.float32", (10, 10)))
kwargs = {"method": "update_named_param", "args": (request.name, request.dtype, tuple(request.shape))}
for connection in connections:
connection.send({"type": "fire_and_forget", "method": "collective_rpc", "kwargs": kwargs})
msg = {"type": "fire_and_forget", "method": "collective_rpc", "kwargs": kwargs}
loop = asyncio.get_running_loop()
await asyncio.gather(*(loop.run_in_executor(None, conn.send, msg) for conn in connections))

return {"message": "Request received, updating named parameter"}

class BatchUpdateWeightsRequest(BaseModel):
    # Each entry is {"name": str, "dtype": str (e.g. "torch.float32"), "shape": list[int]}.
    # Kept as plain dicts to match the client's JSON payload exactly.
    params: list[dict]  # List of {"name": str, "dtype": str, "shape": list[int]}

@app.post("/batch_update_named_params/")
async def batch_update_named_params(request: BatchUpdateWeightsRequest):
    """Batch update: sends all param metadata in one HTTP call, then NCCL broadcasts happen in sequence."""

    # Normalize the JSON dicts into the (name, dtype, shape) tuples the worker RPC expects.
    metadata = [(entry["name"], entry["dtype"], tuple(entry["shape"])) for entry in request.params]
    rpc_kwargs = {"method": "batch_update_named_params", "args": (metadata,)}
    payload = {"type": "fire_and_forget", "method": "collective_rpc", "kwargs": rpc_kwargs}

    # `Connection.send` blocks, so fan out over the executor to keep the event loop free.
    event_loop = asyncio.get_running_loop()
    send_jobs = (event_loop.run_in_executor(None, connection.send, payload) for connection in connections)
    await asyncio.gather(*send_jobs)

    return {"message": f"Batch update started for {len(metadata)} parameters"}

@app.post("/reset_prefix_cache/")
async def reset_prefix_cache():
"""
Expand Down
12 changes: 12 additions & 0 deletions trl/trainer/grpo_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,10 @@ class GRPOConfig(_BaseConfig):
vllm_group_port (`int`, *optional*, defaults to `51216`):
Port number for the weight update group. This is used to communicate with the vLLM server. Unless the port
is occupied, there is no need to change it.
vllm_weight_sync_chunk_size (`int` or `None`, *optional*, defaults to `None`):
Maximum number of total tensor elements per HTTP request during batched weight sync to the vLLM
server. `None` (default) sends all parameters in a single request. Set to a smaller value (e.g.
`500_000_000`) for large models to avoid exceeding HTTP request size limits.

> Parameters that control colocated vLLM execution (only used when `vllm_mode` is `"colocate"`)

Expand Down Expand Up @@ -538,6 +542,14 @@ class GRPOConfig(_BaseConfig):
"Unless the port is occupied, there is no need to change it.",
},
)
vllm_weight_sync_chunk_size: int | None = field(
default=None,
metadata={
"help": "Maximum total tensor elements per HTTP request during batched weight sync to the vLLM server. "
"None (default) sends all parameters in a single request. Set to a smaller value (e.g. 500_000_000) "
"for large models to avoid exceeding HTTP request size limits."
},
)

# Parameters that control colocated vLLM execution (only used when `vllm_mode` is `"colocate"`)
vllm_gpu_memory_utilization: float = field(
Expand Down
3 changes: 2 additions & 1 deletion trl/trainer/grpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -722,8 +722,9 @@ def rollout_func(prompts):
server_base_url=args.vllm_server_base_url,
server_host=args.vllm_server_host,
server_port=args.vllm_server_port,
group_port=args.vllm_group_port,
server_timeout=args.vllm_server_timeout,
group_port=args.vllm_group_port,
weight_sync_chunk_size=args.vllm_weight_sync_chunk_size,
# Colocate mode configuration
tensor_parallel_size=args.vllm_tensor_parallel_size,
gpu_memory_utilization=args.vllm_gpu_memory_utilization,
Expand Down