
[BUG] Batch inference DDP + zero stage 3 = inference code hangs #7128

Open
@ShengYun-Peng

Description

I ran the batch inference code with DeepSpeed generation, not the vLLM one. The code hangs when I set zero stage = 3. I created a minimal code snippet for you to reproduce and debug the error.

import os

import torch
import torch.distributed as dist
from transformers import AutoModelForCausalLM, AutoTokenizer

import deepspeed

# Initialize distributed environment
def setup_distributed():
    dist.init_process_group(backend="nccl", init_method="env://")
    local_rank = int(os.getenv("LOCAL_RANK", 0))
    torch.cuda.set_device(local_rank)
    return local_rank


def load_model(model_name="facebook/opt-1.3b", local_rank=0):
    # Ensure distributed environment is set up
    if not dist.is_initialized():
        dist.init_process_group(backend="nccl", init_method="env://")

    world_size = dist.get_world_size()  # Number of GPUs available
    torch.cuda.set_device(local_rank)  # Assign each process to a GPU

    print(
        f"Loading model {model_name} on rank {local_rank}, using {world_size} GPUs for model parallelism"
    )

    # Load model and tokenizer
    model = AutoModelForCausalLM.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    # ✅ DeepSpeed Inference config for Model Parallelism
    ds_config = {
        # "replace_with_kernel_inject": False,  # Enables optimized inference kernels
        "tensor_parallel": {"tp_size": 1},  # Enables Model Parallelism
        "dtype": "bf16"
        if torch.cuda.is_bf16_supported()
        else "fp16",  # Automatic dtype selection
    }

    # ✅ Initialize DeepSpeed for Model Parallel Inference
    model = deepspeed.init_inference(model, config=ds_config)

    return model, tokenizer


# Perform inference with data parallelism
def batch_inference(model, tokenizer, prompts, local_rank):
    inputs = tokenizer(prompts, return_tensors="pt", padding=True, truncation=True).to(
        f"cuda:{local_rank}"
    )
    with torch.no_grad():
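        # synced_gpus=True keeps every rank looping inside generate() until all
        # ranks are done, which Hugging Face recommends for sharded setups
        # (e.g. DeepSpeed ZeRO stage 3) to avoid deadlocks when prompts on
        # different ranks finish at different lengths.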
        outputs = model.generate(**inputs, max_length=150, synced_gpus=True)
    return tokenizer.batch_decode(outputs, skip_special_tokens=True)


def main():
    local_rank = setup_distributed()
    model, tokenizer = load_model(local_rank=local_rank)

    # Each GPU gets a different batch
    global_batch = [
        [
            "What is AI?",
            "Explain deep learning.",
        ],  # Batch for GPU 0
        [
            "Tell me a joke.",
            "What is reinforcement learning? Tell me all the details",
        ],  # Batch for GPU 1
    ]
    prompts = global_batch[local_rank] if local_rank < len(global_batch) else []

    print(f"GPU {local_rank} prompts:", prompts)
    # Perform batch inference
    results = batch_inference(model, tokenizer, prompts, local_rank)
    print(f"GPU {local_rank} results:", results)

    dist.barrier()  # Ensure all GPUs finish


if __name__ == "__main__":
    main()

Run the code with

NCCL_DEBUG=INFO NCCL_BLOCKING_WAIT=1 NCCL_ASYNC_ERROR_HANDLING=1 deepspeed --num_gpus 2 test_deepspeed.py

The code should run without error because it is pure DDP.
Now change "tensor_parallel": {"tp_size": 1} to "tensor_parallel": {"tp_size": 2} and rerun the code: it hangs forever. Note that the bug happens only when DDP + TP are enabled together.
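
For reference, a minimal sketch of the only change needed to reproduce the hang (same script as above, only the inference config inside load_model differs):

# Same load_model() as above, but with tensor parallelism across both GPUs.
ds_config = {
    "tensor_parallel": {"tp_size": 2},  # was 1; with 2, model.generate() hangs
    "dtype": "bf16" if torch.cuda.is_bf16_supported() else "fp16",
}
model = deepspeed.init_inference(model, config=ds_config)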
