
[Bug/Correctness] Hardcoded device_id=0 + missing CUDAGuard can break multi-GPU correctness (wrong hw_info / stream mismatch) #158

@red1239109-cmd

Description

Hi maintainers,

While reviewing the FMHA forward runner integration, I noticed two correctness issues that can break execution on multi-GPU setups (and can also create subtle stream/device mismatches):

  1. Hardcoded GPU selection (device_id = 0)

In run_fmha_fwd, the hardware info is pinned to GPU0:

cutlass::KernelHardwareInfo hw_info;
hw_info.device_id = 0;
hw_info.sm_count =
    cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);

If the tensors (q/k/v/o/lse) live on a non-zero device, this queries the SM count of the wrong GPU, so the kernel may be launched with incorrect hardware assumptions.

  2. Missing CUDA device guard (at::cuda::CUDAGuard) and stream/device alignment

The code uses at::cuda::getCurrentCUDAStream() at the end:

CUTLASS_CHECK(op.run(at::cuda::getCurrentCUDAStream()));

but does not guard/set the current device to match q.device() (or any other input tensor's device). In multi-GPU scenarios, the “current device” may differ from the tensors' device, leading to:

- the wrong stream/device being used for the launch
- an incorrect hw_info.device_id / sm_count
- potential launch failures or silent misbehavior

Suggested fix

Use a device guard based on an input tensor (e.g., q) and set hw_info.device_id accordingly:

#include <ATen/cuda/CUDAGuard.h>

// Switch the current CUDA device to match the input tensors for the
// lifetime of this scope; the previous device is restored on scope exit.
at::cuda::CUDAGuard device_guard(q.device());
const int dev = at::cuda::current_device();

cutlass::KernelHardwareInfo hw_info;
hw_info.device_id = dev;
hw_info.sm_count =
    cutlass::KernelHardwareInfo::query_device_multiprocessor_count(dev);

// With the guard active, getCurrentCUDAStream() resolves to q's device.
CUTLASS_CHECK(op.run(at::cuda::getCurrentCUDAStream()));

This ensures:

- the correct device is active
- the stream matches the tensor device context
- the hardware info queries the right GPU

Why this matters

Even if most users run on a single GPU, multi-GPU setups are common in training/inference servers. Hardcoding GPU 0 and omitting the device guard can produce correctness issues that are hard to diagnose, especially when the failure is not immediate.

If you'd like, I can provide a small repro snippet that places q/k/v/o on cuda:1 and shows the mismatch.

Thanks!
