#include <cstdio>
#include <cstdlib>
#include <stdexcept>
+#include <unordered_map>

namespace flashinfer {

@@ -627,7 +628,7 @@ static inline float compute_adjusted_freq_host(int half_dim_val, float base, int
// The cache is allocated once and never freed to ensure cudagraph capture compatibility.
////////////////////////////////////////////////////////////////////////////////////////////////////

-static struct {
+struct FreqCacheEntry {
  float* d_ptr = nullptr;
  int alloc_floats = 0;
  int head_dim = 0;
@@ -638,7 +639,11 @@ static struct {
  int num_frame_channels = 0;
  int num_height_channels = 0;
  int num_width_channels = 0;
-} s_freq_cache;
+};
+
+// Per-device frequency table cache. Keyed by CUDA device ID so that
+// multi-GPU usage within a single process is safe.
+static std::unordered_map<int, FreqCacheEntry> s_freq_cache_map;

inline void launchFusedQKNormRope(void const* qkv_in, void* q_out, void* k_out, void* v_out,
                                  int64_t const num_tokens, int const seq_len, int const ppf,
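A note on the per-device keying introduced above: std::unordered_map::operator[] value-initializes a FreqCacheEntry (null d_ptr, zero sizes) the first time a device ID is seen, so each GPU lazily gets its own grow-only table. A minimal self-contained sketch of the pattern (the names Entry, g_cache, and get_device_table are illustrative, not part of this patch):

#include <cuda_runtime.h>
#include <unordered_map>

struct Entry {
  float* d_ptr = nullptr;  // device buffer; intentionally never freed
  int alloc_floats = 0;    // current capacity in floats
};

static std::unordered_map<int, Entry> g_cache;  // keyed by CUDA device ID

inline float* get_device_table(int table_size) {
  int dev = 0;
  cudaGetDevice(&dev);      // key on the *current* device
  Entry& e = g_cache[dev];  // value-initialized on first access
  if (e.alloc_floats < table_size) {  // grow-only reallocation
    if (e.d_ptr != nullptr) cudaFree(e.d_ptr);
    cudaMalloc(&e.d_ptr, table_size * sizeof(float));
    e.alloc_floats = table_size;
  }
  return e.d_ptr;
}

Like the patch itself, this sketch assumes single-threaded launches; concurrent mutation of the map would need external locking.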
@@ -657,22 +662,26 @@ inline void launchFusedQKNormRope(void const* qkv_in, void* q_out, void* k_out,
    FLASHINFER_FUSED_CHECK(attention_factor == 1.0f);
  }

+  int device_id;
+  cudaGetDevice(&device_id);
+  FreqCacheEntry& cache = s_freq_cache_map[device_id];
+
  int const table_size = head_dim / 2;

-  if (s_freq_cache.alloc_floats < table_size) {
-    if (s_freq_cache.d_ptr != nullptr) {
-      cudaFree(s_freq_cache.d_ptr);
+  if (cache.alloc_floats < table_size) {
+    if (cache.d_ptr != nullptr) {
+      cudaFree(cache.d_ptr);
    }
-    cudaMalloc(&s_freq_cache.d_ptr, table_size * sizeof(float));
-    s_freq_cache.alloc_floats = table_size;
+    cudaMalloc(&cache.d_ptr, table_size * sizeof(float));
+    cache.alloc_floats = table_size;
  }

  bool cache_miss =
-      (s_freq_cache.head_dim != head_dim || s_freq_cache.base != base ||
-       s_freq_cache.factor != factor || s_freq_cache.low != low || s_freq_cache.high != high ||
-       s_freq_cache.num_frame_channels != num_frame_channels ||
-       s_freq_cache.num_height_channels != num_height_channels ||
-       s_freq_cache.num_width_channels != num_width_channels);
+      (cache.head_dim != head_dim || cache.base != base || cache.factor != factor ||
+       cache.low != low || cache.high != high ||
+       cache.num_frame_channels != num_frame_channels ||
+       cache.num_height_channels != num_height_channels ||
+       cache.num_width_channels != num_width_channels);

  if (cache_miss) {
    FLASHINFER_FUSED_CHECK(table_size <= 128);
@@ -693,17 +702,16 @@ inline void launchFusedQKNormRope(void const* qkv_in, void* q_out, void* k_out,

    FLASHINFER_FUSED_CHECK(offset == table_size);

-    cudaMemcpy(s_freq_cache.d_ptr, h_freq_table, table_size * sizeof(float),
-               cudaMemcpyHostToDevice);
-
-    s_freq_cache.head_dim = head_dim;
-    s_freq_cache.base = base;
-    s_freq_cache.factor = factor;
-    s_freq_cache.low = low;
-    s_freq_cache.high = high;
-    s_freq_cache.num_frame_channels = num_frame_channels;
-    s_freq_cache.num_height_channels = num_height_channels;
-    s_freq_cache.num_width_channels = num_width_channels;
+    cudaMemcpy(cache.d_ptr, h_freq_table, table_size * sizeof(float), cudaMemcpyHostToDevice);
+
+    cache.head_dim = head_dim;
+    cache.base = base;
+    cache.factor = factor;
+    cache.low = low;
+    cache.high = high;
+    cache.num_frame_channels = num_frame_channels;
+    cache.num_height_channels = num_height_channels;
+    cache.num_width_channels = num_width_channels;
  }

  int const maxHeads = max(max(num_heads_q, num_heads_k), num_heads_v);
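Why the never-freed cache matters for graphs: synchronous cudaMalloc/cudaMemcpy calls (the cache-miss path above) cannot be captured into a CUDA graph, so callers using capture would typically warm the cache with one eager invocation; captured replays then take the cache-hit path and only dereference cache.d_ptr. A hypothetical sketch of that usage, reusing the illustrative get_device_table helper from the earlier sketch:

#include <cuda_runtime.h>

__global__ void read_table(float const* table, float* out) { out[0] = table[0]; }

void capture_example(cudaStream_t stream, float* d_out) {
  // Warm-up outside capture: any cudaMalloc/cudaMemcpy happens here.
  float* table = get_device_table(64);

  cudaGraph_t graph;
  cudaStreamBeginCapture(stream, cudaStreamCaptureModeGlobal);
  // Inside capture: cache hit only, so nothing non-capturable runs.
  read_table<<<1, 1, 0, stream>>>(table, d_out);
  cudaStreamEndCapture(stream, &graph);
  // Graph instantiation and launch omitted for brevity.
}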
@@ -723,7 +731,7 @@ inline void launchFusedQKNormRope(void const* qkv_in, void* q_out, void* k_out,
      <<<gridDim, blockDim, 0, stream>>>(                                                    \
          reinterpret_cast<__nv_bfloat16 const*>(qkv_in), q_out, k_out, v_out, num_heads_q,  \
          num_heads_k, num_heads_v, eps, reinterpret_cast<__nv_bfloat16 const*>(q_weight),   \
-          reinterpret_cast<__nv_bfloat16 const*>(k_weight), s_freq_cache.d_ptr, num_tokens, \
+          reinterpret_cast<__nv_bfloat16 const*>(k_weight), cache.d_ptr, num_tokens,        \
          IntFastDiv(seq_len), IntFastDiv(ppw), IntFastDiv(pph * ppw), num_frame_channels,   \
          num_height_channels, num_width_channels, attention_factor, is_qk_norm,             \
          output_quant_scale, v_quant_scale);                                                \
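Finally, the launch macro wraps the divisors seq_len, ppw, and pph * ppw in IntFastDiv, presumably so the kernel replaces per-element integer division (slow on GPUs) with a precomputed multiply-and-shift. A generic sketch of that classic technique, assuming nothing about flashinfer's IntFastDiv beyond construction from an int (relies on the __int128 compiler extension):

// Illustrative magic-number divider; NOT flashinfer's IntFastDiv.
// With m = ceil(2^64 / d), for 32-bit n and d > 1: n / d == (m * n) >> 64.
struct FastDivSketch {
  unsigned d;
  unsigned long long m;
  explicit FastDivSketch(unsigned d_) : d(d_), m(d_ > 1 ? ~0ULL / d_ + 1 : 0) {}
  unsigned div(unsigned n) const {
    if (d == 1) return n;  // ceil(2^64 / 1) does not fit in 64 bits
    return static_cast<unsigned>((static_cast<unsigned __int128>(m) * n) >> 64);
  }
};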