
[Bugfix] fix quickreduce acc error in cudagraph mode#29508

Open
haoyangli0109 wants to merge 1 commit into sgl-project:main from haoyangli0109:lhy/fix_qr_graph

@haoyangli0109 haoyangli0109 commented Jun 27, 2026

Contributor

1. cause:
Once the graph is captured, flag_color is baked into the graph and stays the same on every replay.
→ Every replay writes the same flag value, so it cannot be distinguished from the value left over from the previous round.
→ The waiting side sees the old value, treats its wait as satisfied, and proceeds immediately.
→ The data for the current round has not been fully transferred yet, so the waiting side reads the previous round's data and the result is wrong.
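
The chain above can be illustrated with a toy, single-threaded model. This is only a sketch and not the quick-reduce kernel: `run_rounds`, `flag_slot`, `data_slot`, and `advance_color` are hypothetical names, and the model assumes the worst-case ordering in which the waiting side polls the flag before the writer has stored the new payload.

```python
def run_rounds(values, advance_color):
    """Toy model of a flag-based handshake between a writer and a waiter.

    Each round the waiter polls the shared flag BEFORE the writer has stored
    the new payload (worst-case interleaving). Returns what the waiter read.
    """
    flag_slot = 0     # shared flag word, persists across rounds
    data_slot = None  # shared payload, persists across rounds
    color = 1         # flag value the waiter expects this round
    seen = []
    for v in values:
        if advance_color:
            color += 1  # a fresh color per round, as in eager mode
        if flag_slot == color:
            # Residual flag from the previous round already matches:
            # the wait passes early and the waiter reads stale data.
            seen.append(data_slot)
            data_slot, flag_slot = v, color  # writer finishes too late
        else:
            # Flag does not match yet: the waiter blocks until the writer
            # has stored the payload and then the flag.
            data_slot, flag_slot = v, color
            seen.append(data_slot)
    return seen


print(run_rounds([10, 20, 30], advance_color=False))  # color frozen by capture
print(run_rounds([10, 20, 30], advance_color=True))   # color advances per round
```

With a frozen color, every round after the first returns the previous round's payload. With an advancing color, each round returns its own payload.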

2. repro code

import argparse
import multiprocessing
import os

# quick-reduce reads these envs when QuickAllReduce is constructed; set them
# before importing vllm so the (cached) env values are correct.
# FP = lossless regime so the all-reduce is bit-exact for small fp16 integers.
os.environ.setdefault("VLLM_ROCM_QUICK_REDUCE_QUANTIZATION", "FP")
os.environ.setdefault("VLLM_ROCM_QUICK_REDUCE_CAST_BF16_TO_FP16", "0")

import torch
import torch.distributed as dist

from vllm.distributed.device_communicators.quick_all_reduce import QuickAllReduce


def worker(rank, world_size):
    device = torch.device(f"cuda:{rank}")
    torch.cuda.set_device(device)

    # gloo (CPU) group: QuickAllReduce must be attached to a non-NCCL group; it
    # is used only for the one-time IPC-handle exchange.
    dist.init_process_group(
        backend="gloo",
        init_method="tcp://127.0.0.1:29500",
        rank=rank,
        world_size=world_size,
    )

    qr = QuickAllReduce(group=dist.group.WORLD, device=device)
    assert not qr.disabled, (
        "quick-reduce unavailable on this arch/env "
        "(needs ROCm MI300 gfx94/gfx95, even GPU count, same node, "
        "and a non-NONE VLLM_ROCM_QUICK_REDUCE_QUANTIZATION)."
    )

    regime = os.environ["VLLM_ROCM_QUICK_REDUCE_QUANTIZATION"]
    N = 1 << 21  # 2M fp16 = 4 MB, above qr thresholds for the direct call path
    inp = torch.empty(N, dtype=torch.float16, device=device)
    out = torch.empty(N, dtype=torch.float16, device=device)

    # Every rank contributes the SAME value v in a round, so the true cross-rank
    # all-reduce sum is simply world * v.
    def expected(v):
        return float(world_size * v)

    if rank == 0:
        print(
            f"[repro] world_size={world_size} elems={N} regime={regime} fp16",
            flush=True,
        )

    # Warmup, then capture a graph with EXACTLY ONE quick-reduce (isolated qr,
    # so it is the sole writer of its flag slot -- the condition that triggers
    # the stale-flag bug).
    inp.fill_(1.0)
    qr.quick_all_reduce(inp, out=out)
    torch.cuda.synchronize()
    dist.barrier()

    g = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g):
        qr.quick_all_reduce(inp, out=out)
    torch.cuda.synchronize()
    dist.barrier()

    for v in range(10):
        inp.fill_(float(v))  # in-place: same value on every rank
        dist.barrier()
        g.replay()
        torch.cuda.synchronize()
        dist.barrier()
        got = out.float()
        expect = expected(v)
        if rank == 0:
            print(f"round {v}: got={got[:10]}, expected={expect}", flush=True)

    dist.destroy_process_group()


def run_multiprocessing(world_size):
    processes = []
    for rank in range(world_size):
        p = multiprocessing.Process(target=worker, args=(rank, world_size))
        p.start()
        processes.append(p)
    for p in processes:
        p.join()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--world_size",
        type=int,
        default=4,
        help="number of processes / GPUs to use",
    )
    args = parser.parse_args()

    multiprocessing.set_start_method("spawn")
    run_multiprocessing(world_size=args.world_size)


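3. solution
Following customallreduce, keep the flag_color for each block in device memory and pass a pointer to it into the kernel. The kernel then reads and advances flag_color on the device at every launch, so graph replays no longer reuse the value that was fixed at capture time.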
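
A minimal sketch of why a pointer helps, assuming a graph replay reuses exactly the arguments recorded at capture. `ToyGraph`, `kernel_host_color`, and `kernel_device_counter` are hypothetical names for illustration, and a one-element Python list stands in for one block's counter in device memory.

```python
class ToyGraph:
    """Records calls with their arguments frozen at capture, then replays them."""

    def __init__(self):
        self._calls = []

    def capture(self, fn, *args):
        self._calls.append((fn, args))

    def replay(self):
        return [fn(*args) for fn, args in self._calls]


def kernel_host_color(color):
    # The color is passed by value, so the captured value is reused
    # unchanged on every replay.
    return color


def kernel_device_counter(counter):
    # The counter is passed by reference (a stand-in for a device pointer);
    # the kernel advances it itself, so every replay sees a new color.
    counter[0] += 1
    return counter[0]


def colors_seen(kernel, arg, replays):
    graph = ToyGraph()
    graph.capture(kernel, arg)
    return [graph.replay()[0] for _ in range(replays)]


print(colors_seen(kernel_host_color, 1, 3))        # same color every replay
print(colors_seen(kernel_device_counter, [0], 3))  # color advances per replay
```

The captured argument is identical in both cases. Only the by-reference version lets the state behind that argument change between replays.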

4. This change will not affect performance.


CI States

Latest PR Test (Base): ❌ Run #28286525109
Latest PR Test (Extra): ❌ Run #28286525046

Signed-off-by: Haoyang Li <lihaoyang0109@gmail.com>

@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

This pull request moves the flag color counter logic from the host to the device using a new d_flag_counters buffer, enabling CUDA-graph replays to correctly advance the flag color inside the kernel. The reviewer suggests extending the leak-prevention logic in destroy() to also free dbuffer and dbuffer_list independently of the initialized flag, as they are also prone to leaking during partial initialization failures.


Comment thread sgl-kernel/csrc/allreduce/quick_all_reduce.h