vllm-dp-support #3303

Open · wants to merge 1 commit into main
131 changes: 117 additions & 14 deletions trl/scripts/vllm_serve.py
@@ -155,6 +155,8 @@ class ScriptArguments:
Revision to use for the model. If not specified, the default branch will be used.
tensor_parallel_size (`int`, *optional*, defaults to `1`):
Number of tensor parallel workers to use.
dp_size (`int`, *optional*, defaults to `1`):
Number of data parallel workers to use.
host (`str`, *optional*, defaults to `"0.0.0.0"`):
Host address to run the server on.
port (`int`, *optional*, defaults to `8000`):
@@ -185,6 +187,10 @@ class ScriptArguments:
default=1,
metadata={"help": "Number of tensor parallel workers to use."},
)
dp_size: int = field(
default=1,
metadata={"help": "Number of data parallel workers to use."},
)
host: str = field(
default="0.0.0.0",
metadata={"help": "Host address to run the server on."},
@@ -226,6 +232,30 @@ class ScriptArguments:
)


def init_llm_instance(dp_rank, llm_instances, model, revision, tensor_parallel_size,
gpu_memory_utilization, dtype, enable_prefix_caching, max_model_len,
dp_size):
"""Initialize an LLM instance for a given data parallel rank."""
# Set environment variables for DP
os.environ["VLLM_DP_RANK"] = str(dp_rank)
os.environ["VLLM_DP_RANK_LOCAL"] = str(dp_rank)
os.environ["VLLM_DP_SIZE"] = str(dp_size)

# Create LLM instance
llm = LLM(
model=model,
revision=revision,
tensor_parallel_size=tensor_parallel_size,
gpu_memory_utilization=gpu_memory_utilization,
dtype=dtype,
enable_prefix_caching=enable_prefix_caching,
max_model_len=max_model_len,
worker_extension_cls="trl.scripts.vllm_serve.WeightSyncWorkerExtension",
)

# Store LLM instance in the shared list
llm_instances[dp_rank] = llm
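
For illustration, the per-rank environment this helper exports can be checked without constructing any engine. The snippet below only mirrors the three assignments above for a hypothetical `dp_size` of 2; it is a sketch, not part of the PR:

```python
# Sketch: reproduce the environment init_llm_instance would export per rank
# for a hypothetical dp_size of 2 (no vLLM engine is constructed here).
dp_size = 2
for dp_rank in range(dp_size):
    env = {
        "VLLM_DP_RANK": str(dp_rank),
        "VLLM_DP_RANK_LOCAL": str(dp_rank),
        "VLLM_DP_SIZE": str(dp_size),
    }
    print(f"rank {dp_rank}: {env}")
# rank 0: {'VLLM_DP_RANK': '0', 'VLLM_DP_RANK_LOCAL': '0', 'VLLM_DP_SIZE': '2'}
# rank 1: {'VLLM_DP_RANK': '1', 'VLLM_DP_RANK_LOCAL': '1', 'VLLM_DP_SIZE': '2'}
```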

def main(script_args: ScriptArguments):
if not is_fastapi_available():
raise ImportError(
@@ -245,19 +275,50 @@ def main(script_args: ScriptArguments):
if not is_vllm_available():
raise ImportError("vLLM is required to run the vLLM serve script. Please install it using `pip install vllm`.")

llm = LLM(
model=script_args.model,
revision=script_args.revision,
tensor_parallel_size=script_args.tensor_parallel_size,
gpu_memory_utilization=script_args.gpu_memory_utilization,
dtype=script_args.dtype,
# Automatic Prefix Caching caches the KV cache of existing queries, so that a new query can
# directly reuse the KV cache if it shares the same prefix with one of the existing queries.
# This is particularly useful here because we generate completions from the same prompts.
enable_prefix_caching=script_args.enable_prefix_caching,
max_model_len=script_args.max_model_len,
worker_extension_cls="trl.scripts.vllm_serve.WeightSyncWorkerExtension",
)
# Set up DP configuration
dp_size = script_args.dp_size

# Initialize LLM instances for data parallelism
if dp_size > 1:
from multiprocessing import Process, Manager

# Use a manager to share data between processes
manager = Manager()
llm_instances = manager.list([None] * dp_size)

# Create and start DP processes
dp_processes = []

# Start processes for initializing LLM instances
for dp_rank in range(dp_size):
p = Process(
target=init_llm_instance,
args=(dp_rank, llm_instances, script_args.model, script_args.revision,
script_args.tensor_parallel_size, script_args.gpu_memory_utilization,
script_args.dtype, script_args.enable_prefix_caching,
script_args.max_model_len, dp_size,)
)
p.start()
dp_processes.append(p)

# Wait for all processes to complete initialization
for p in dp_processes:
p.join()

# Use the first LLM instance for the server
llm = llm_instances[0]
else:
# Single instance case
llm = LLM(
model=script_args.model,
revision=script_args.revision,
tensor_parallel_size=script_args.tensor_parallel_size,
gpu_memory_utilization=script_args.gpu_memory_utilization,
dtype=script_args.dtype,
enable_prefix_caching=script_args.enable_prefix_caching,
max_model_len=script_args.max_model_len,
worker_extension_cls="trl.scripts.vllm_serve.WeightSyncWorkerExtension",
)

app = FastAPI()

@@ -284,6 +345,22 @@ async def get_tensor_parallel_size():
```
"""
return {"tensor_parallel_size": llm.llm_engine.vllm_config.parallel_config.tensor_parallel_size}

@app.get("/get_dp_size/")
async def get_dp_size():
"""
Retrieves the data parallel size.

Returns:
`dict`:
A dictionary containing the data parallel size.

Example response:
```json
{"dp_size": 2}
```
"""
return {"dp_size": dp_size}

class GenerateRequest(BaseModel):
prompts: list[str]
@@ -340,7 +417,33 @@ async def generate(request: GenerateRequest):
max_tokens=request.max_tokens,
guided_decoding=guided_decoding,
)
all_outputs = llm.generate(request.prompts, sampling_params=sampling_params)

if dp_size > 1:
# Distribute prompts across DP ranks
all_outputs = []
prompts = request.prompts
prompts_per_rank = (len(prompts) + dp_size - 1) // dp_size
# Prepare prompt batches for each DP rank
prompt_batches = []
for dp_rank in range(dp_size):
start_idx = dp_rank * prompts_per_rank
end_idx = min(start_idx + prompts_per_rank, len(prompts))
if start_idx < len(prompts):
prompt_batches.append(prompts[start_idx:end_idx])
else:
# Empty batch for this rank
prompt_batches.append([])

# Process each batch with its corresponding LLM instance
for dp_rank, batch in enumerate(prompt_batches):
if batch: # Skip empty batches
llm_instance = llm_instances[dp_rank]
batch_outputs = llm_instance.generate(batch, sampling_params=sampling_params)
all_outputs.extend(batch_outputs)
Review comment from a Member on lines +438 to +442: is it sequential?

else:
# Single LLM case
all_outputs = llm.generate(request.prompts, sampling_params=sampling_params)

completion_ids = [list(output.token_ids) for outputs in all_outputs for output in outputs.outputs]
return {"completion_ids": completion_ids}
