Hi maintainers,
While reviewing the FMHA forward runner integration, I noticed two correctness issues that can break execution on multi-GPU setups (and can also create subtle stream/device mismatches):
- Hardcoded GPU selection (device_id = 0)
In run_fmha_fwd, the hardware info is pinned to GPU0:
    cutlass::KernelHardwareInfo hw_info;
    hw_info.device_id = 0;
    hw_info.sm_count =
        cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
If the tensors (q/k/v/o/lse) live on a non-zero device, this queries the SM count of GPU 0 rather than the device that holds the tensors, so the kernel may be launched with incorrect hardware assumptions.
- Missing CUDA device guard (at::cuda::CUDAGuard) and stream/device alignment
The code uses at::cuda::getCurrentCUDAStream() at the end:
    CUTLASS_CHECK(op.run(at::cuda::getCurrentCUDAStream()));
but does not guard/set the current device to match q.device() (or any input tensor). In multi-GPU scenarios, the “current device” may differ from the tensor’s device, leading to:
  - the wrong stream/device being used
  - an incorrect hw_info.device_id / sm_count
  - potential launch failures or silent misbehavior
Suggested fix
Use a device guard based on an input tensor (e.g., q) and set hw_info.device_id accordingly:
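    #include <ATen/cuda/CUDAGuard.h>

    // Make q's device the current device for the rest of this scope.
    at::cuda::CUDAGuard device_guard(q.device());
    const int dev = at::cuda::current_device();

    // Query hardware info for the device that actually holds the tensors.
    cutlass::KernelHardwareInfo hw_info;
    hw_info.device_id = dev;
    hw_info.sm_count =
        cutlass::KernelHardwareInfo::query_device_multiprocessor_count(dev);

    // With the guard active, the current stream belongs to q's device.
    CUTLASS_CHECK(op.run(at::cuda::getCurrentCUDAStream()));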
This ensures:
  - the correct device is active
  - the stream matches the tensor device context
  - the hardware info queries the right GPU
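To illustrate why the guard matters, here is a minimal, CUDA-free sketch of the RAII pattern involved. Everything in it is a hypothetical stand-in: the `mock` namespace, `DeviceGuard`, `HardwareInfo`, and the SM counts are made up for illustration and are not the PyTorch or CUTLASS APIs. It only models the behavior the points above rely on, namely that the guard switches the current device for the scope and restores the previous one afterwards.

```cpp
#include <cassert>

// Hypothetical stand-ins for the CUDA runtime (NOT the real API).
namespace mock {

constexpr int kDeviceCount = 2;
constexpr int kSmCounts[kDeviceCount] = {108, 132};  // made-up SM counts

// Per-thread "current device", starting at 0 as on a fresh thread.
inline int& current_device() {
  thread_local int dev = 0;
  return dev;
}

// Pretend hardware query: SM count of the given device.
inline int sm_count(int device) {
  assert(device >= 0 && device < kDeviceCount);
  return kSmCounts[device];
}

// RAII guard: switch to `device` on construction, restore the previous
// device on destruction.
class DeviceGuard {
 public:
  explicit DeviceGuard(int device) : prev_(current_device()) {
    current_device() = device;
  }
  ~DeviceGuard() { current_device() = prev_; }
  DeviceGuard(const DeviceGuard&) = delete;
  DeviceGuard& operator=(const DeviceGuard&) = delete;

 private:
  int prev_;
};

}  // namespace mock

// Simplified stand-in for the hardware-info struct.
struct HardwareInfo {
  int device_id;
  int sm_count;
};

// Mirrors the reported bug: always describes device 0.
HardwareInfo hw_info_hardcoded() {
  return HardwareInfo{0, mock::sm_count(0)};
}

// Mirrors the suggested fix: guard on the tensor's device, then query it.
HardwareInfo hw_info_guarded(int tensor_device) {
  mock::DeviceGuard guard(tensor_device);
  const int dev = mock::current_device();
  return HardwareInfo{dev, mock::sm_count(dev)};
}
```

With a tensor on mock device 1, `hw_info_guarded(1)` describes device 1, while `hw_info_hardcoded()` still describes device 0. Once the guarded call returns, the current device is back to whatever it was before.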
Why this matters
Even if most users run on a single GPU, multi-GPU setups are common in training/inference servers. Hardcoding GPU 0 and omitting the device guard can produce correctness issues that are hard to diagnose, especially when the failure is not immediate.
If you'd like, I can provide a small repro snippet that places q/k/v/o on cuda:1 and shows the mismatch.
Thanks!