Skip to content

Commit 693a56e

Browse files
committed
Make frequency table cache per-device for multi-GPU safety
Replace the single global s_freq_cache with a per-device std::unordered_map<int, FreqCacheEntry> keyed by CUDA device ID. This prevents crashes when the kernel is called on different GPUs within the same process (the old single cache held a device pointer that was only valid on the GPU that allocated it). AI-assisted. Made-with: Cursor
1 parent ddfdbe9 commit 693a56e

1 file changed

Lines changed: 32 additions & 24 deletions

File tree

include/flashinfer/fused_qk_rmsnorm_rope.cuh

Lines changed: 32 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
#include <cstdio>
2727
#include <cstdlib>
2828
#include <stdexcept>
29+
#include <unordered_map>
2930

3031
namespace flashinfer {
3132

@@ -627,7 +628,7 @@ static inline float compute_adjusted_freq_host(int half_dim_val, float base, int
627628
// The cache is allocated once and never freed to ensure cudagraph capture compatibility.
628629
////////////////////////////////////////////////////////////////////////////////////////////////////
629630

630-
static struct {
631+
struct FreqCacheEntry {
631632
float* d_ptr = nullptr;
632633
int alloc_floats = 0;
633634
int head_dim = 0;
@@ -638,7 +639,11 @@ static struct {
638639
int num_frame_channels = 0;
639640
int num_height_channels = 0;
640641
int num_width_channels = 0;
641-
} s_freq_cache;
642+
};
643+
644+
// Per-device frequency table cache. Keyed by CUDA device ID so that
645+
// multi-GPU usage within a single process is safe.
646+
static std::unordered_map<int, FreqCacheEntry> s_freq_cache_map;
642647

643648
inline void launchFusedQKNormRope(void const* qkv_in, void* q_out, void* k_out, void* v_out,
644649
int64_t const num_tokens, int const seq_len, int const ppf,
@@ -657,22 +662,26 @@ inline void launchFusedQKNormRope(void const* qkv_in, void* q_out, void* k_out,
657662
FLASHINFER_FUSED_CHECK(attention_factor == 1.0f);
658663
}
659664

665+
int device_id;
666+
cudaGetDevice(&device_id);
667+
FreqCacheEntry& cache = s_freq_cache_map[device_id];
668+
660669
int const table_size = head_dim / 2;
661670

662-
if (s_freq_cache.alloc_floats < table_size) {
663-
if (s_freq_cache.d_ptr != nullptr) {
664-
cudaFree(s_freq_cache.d_ptr);
671+
if (cache.alloc_floats < table_size) {
672+
if (cache.d_ptr != nullptr) {
673+
cudaFree(cache.d_ptr);
665674
}
666-
cudaMalloc(&s_freq_cache.d_ptr, table_size * sizeof(float));
667-
s_freq_cache.alloc_floats = table_size;
675+
cudaMalloc(&cache.d_ptr, table_size * sizeof(float));
676+
cache.alloc_floats = table_size;
668677
}
669678

670679
bool cache_miss =
671-
(s_freq_cache.head_dim != head_dim || s_freq_cache.base != base ||
672-
s_freq_cache.factor != factor || s_freq_cache.low != low || s_freq_cache.high != high ||
673-
s_freq_cache.num_frame_channels != num_frame_channels ||
674-
s_freq_cache.num_height_channels != num_height_channels ||
675-
s_freq_cache.num_width_channels != num_width_channels);
680+
(cache.head_dim != head_dim || cache.base != base || cache.factor != factor ||
681+
cache.low != low || cache.high != high ||
682+
cache.num_frame_channels != num_frame_channels ||
683+
cache.num_height_channels != num_height_channels ||
684+
cache.num_width_channels != num_width_channels);
676685

677686
if (cache_miss) {
678687
FLASHINFER_FUSED_CHECK(table_size <= 128);
@@ -693,17 +702,16 @@ inline void launchFusedQKNormRope(void const* qkv_in, void* q_out, void* k_out,
693702

694703
FLASHINFER_FUSED_CHECK(offset == table_size);
695704

696-
cudaMemcpy(s_freq_cache.d_ptr, h_freq_table, table_size * sizeof(float),
697-
cudaMemcpyHostToDevice);
698-
699-
s_freq_cache.head_dim = head_dim;
700-
s_freq_cache.base = base;
701-
s_freq_cache.factor = factor;
702-
s_freq_cache.low = low;
703-
s_freq_cache.high = high;
704-
s_freq_cache.num_frame_channels = num_frame_channels;
705-
s_freq_cache.num_height_channels = num_height_channels;
706-
s_freq_cache.num_width_channels = num_width_channels;
705+
cudaMemcpy(cache.d_ptr, h_freq_table, table_size * sizeof(float), cudaMemcpyHostToDevice);
706+
707+
cache.head_dim = head_dim;
708+
cache.base = base;
709+
cache.factor = factor;
710+
cache.low = low;
711+
cache.high = high;
712+
cache.num_frame_channels = num_frame_channels;
713+
cache.num_height_channels = num_height_channels;
714+
cache.num_width_channels = num_width_channels;
707715
}
708716

709717
int const maxHeads = max(max(num_heads_q, num_heads_k), num_heads_v);
@@ -723,7 +731,7 @@ inline void launchFusedQKNormRope(void const* qkv_in, void* q_out, void* k_out,
723731
<<<gridDim, blockDim, 0, stream>>>( \
724732
reinterpret_cast<__nv_bfloat16 const*>(qkv_in), q_out, k_out, v_out, num_heads_q, \
725733
num_heads_k, num_heads_v, eps, reinterpret_cast<__nv_bfloat16 const*>(q_weight), \
726-
reinterpret_cast<__nv_bfloat16 const*>(k_weight), s_freq_cache.d_ptr, num_tokens, \
734+
reinterpret_cast<__nv_bfloat16 const*>(k_weight), cache.d_ptr, num_tokens, \
727735
IntFastDiv(seq_len), IntFastDiv(ppw), IntFastDiv(pph * ppw), num_frame_channels, \
728736
num_height_channels, num_width_channels, attention_factor, is_qk_norm, \
729737
output_quant_scale, v_quant_scale); \

0 commit comments

Comments
 (0)