@@ -227,7 +227,13 @@ DEVICE_INLINE void quantize_fp8_kv(
     __half2* qparam = nullptr,
     bool do_norm = false);
 
-DEVICE_INLINE void per_row_norm(fx4& dst);
+DEVICE_INLINE void per_row_norm(fx4& dst) {
+  float sum = fx4_dot(dst, dst);
+  // Warp reduce sum
+  sum = warpReduceSum(sum);
+  float rsqr = rsqrtf(sum / D_H);
+  dst = fx4_scale(dst, rsqr);
+}
 DEVICE_INLINE void per_row_amax(fx4& dst, float* amax);
 DEVICE_INLINE void per_head_amax(fx4& dst, float* amax);
 __global__ void nope_qkv_varseq_prefill_kernel(
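The per_row_norm body added above normalizes one D_H-wide head row to unit RMS: each lane of the warp holds a 4-float fx4 chunk of the row, the per-lane sum of squares is warp-reduced, and the row is rescaled by rsqrtf(sum / D_H). The following standalone CUDA sketch shows the same reduce-then-scale pattern under simplified assumptions; fx4, fx4_dot, fx4_scale, warpReduceSum, D_H, and the demo kernel are stand-ins written for illustration, not the fbgemm_gpu definitions.

#include <cuda_runtime.h>

// Stand-ins for the fbgemm_gpu helpers (assumed layout: 32 lanes x 4 floats = D_H).
constexpr int D_H = 128;
constexpr int kThreadsPerWarp = 32;

struct fx4 {
  float x, y, z, w;
};

__device__ inline float fx4_dot(const fx4& a, const fx4& b) {
  return a.x * b.x + a.y * b.y + a.z * b.z + a.w * b.w;
}

__device__ inline fx4 fx4_scale(const fx4& a, float s) {
  return fx4{a.x * s, a.y * s, a.z * s, a.w * s};
}

// Butterfly reduction: every lane ends up holding the full warp sum.
__device__ inline float warpReduceSum(float v) {
  for (int offset = kThreadsPerWarp / 2; offset > 0; offset >>= 1) {
    v += __shfl_xor_sync(0xffffffff, v, offset);
  }
  return v;
}

// Same shape as the per_row_norm added in the hunk: one warp owns one row,
// and the row is scaled to unit RMS (x <- x * rsqrt(mean(x^2))).
__device__ inline void per_row_norm(fx4& dst) {
  float sum = fx4_dot(dst, dst);   // per-lane partial sum of squares
  sum = warpReduceSum(sum);        // sum of squares over the whole row
  float rsqr = rsqrtf(sum / D_H);  // reciprocal RMS of the row
  dst = fx4_scale(dst, rsqr);
}

// Demo kernel: launch with one 32-thread block per row of D_H floats.
__global__ void per_row_norm_demo(float* rows, int num_rows) {
  int row = blockIdx.x;
  int lane = threadIdx.x;
  if (row >= num_rows) return;
  float* p = rows + row * D_H + lane * 4;
  fx4 v{p[0], p[1], p[2], p[3]};
  per_row_norm(v);
  p[0] = v.x; p[1] = v.y; p[2] = v.z; p[3] = v.w;
}

In the real kernel D_H is a compile-time parameter and the fx4 chunks come from the QKV tensor; the point of the sketch is only the warp-level sum-of-squares followed by a single rsqrtf scale.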
@@ -2850,7 +2856,6 @@ at::Tensor quantize_qkv_per_head(
     // HH += N_KVH_L * 2;
     qparam_k_ptr = qparam_k.value().data_ptr<float>();
     qparam_v_ptr = qparam_v.value().data_ptr<float>();
-    CHECK_EQ(HH, 7);
   }
   auto num_warps = B_T * HH;
   dim3 block_size(kThreadsPerWarp, kWarpsPerBlock);
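For the launch shape touched by this hunk: the host code assigns one warp per (token, head) pair, so num_warps = B_T * HH and a block is kThreadsPerWarp x kWarpsPerBlock threads; the grid size is presumably num_warps rounded up to whole blocks, though that line falls outside the hunk. A small host-side sketch of that arithmetic, with hypothetical constants and sizes:

#include <cstdint>
#include <cstdio>

// Assumed values; in fbgemm_gpu these are compile-time constants.
constexpr uint32_t kThreadsPerWarp = 32;
constexpr uint32_t kWarpsPerBlock = 8;

int main() {
  uint32_t B_T = 100; // total tokens across the batch (hypothetical)
  uint32_t HH = 12;   // Q + K + V heads per token (hypothetical)

  // One warp per (token, head) pair.
  uint32_t num_warps = B_T * HH;
  // Round up so every warp has a block slot.
  uint32_t num_blocks = (num_warps + kWarpsPerBlock - 1) / kWarpsPerBlock;

  printf("block = %u x %u threads, grid = %u blocks for %u warps\n",
         kThreadsPerWarp, kWarpsPerBlock, num_blocks, num_warps);
  return 0;
}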
@@ -2883,13 +2888,6 @@ at::Tensor quantize_qkv_per_head(
   C10_CUDA_KERNEL_LAUNCH_CHECK();
   return scale_q;
 }
-DEVICE_INLINE void per_row_norm(fx4& dst) {
-  float sum = fx4_dot(dst, dst);
-  // Warp reduce sum
-  sum = warpReduceSum(sum);
-  float rsqr = rsqrtf(sum / D_H);
-  dst = fx4_scale(dst, rsqr);
-}
 
 DEVICE_INLINE void per_head_amax(fx4& dst, float* amax) {
   dst = fx4_abs(dst);
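per_head_amax above begins by taking the element-wise absolute value of the fx4 chunk; the rest of the body is outside the hunk. A common way to finish such a helper is a warp max-reduction followed by an atomic max into the per-head amax slot, sketched below under that assumption; fx4_abs, fx4_max, warpReduceMax, and the float atomic-max trick are stand-ins for illustration, not the fbgemm_gpu code.

#include <cuda_runtime.h>

struct fx4 {
  float x, y, z, w;
};

__device__ inline fx4 fx4_abs(const fx4& a) {
  return fx4{fabsf(a.x), fabsf(a.y), fabsf(a.z), fabsf(a.w)};
}

__device__ inline float fx4_max(const fx4& a) {
  return fmaxf(fmaxf(a.x, a.y), fmaxf(a.z, a.w));
}

__device__ inline float warpReduceMax(float v) {
  for (int offset = 16; offset > 0; offset >>= 1) {
    v = fmaxf(v, __shfl_xor_sync(0xffffffff, v, offset));
  }
  return v;
}

// For non-negative floats the IEEE-754 bit pattern is order-preserving as a
// signed int, so an integer atomicMax on the reinterpreted bits acts as a
// float atomic max (the amax slot must start at 0).
__device__ inline void atomicMaxNonNegF32(float* addr, float val) {
  atomicMax(reinterpret_cast<int*>(addr), __float_as_int(val));
}

// Hypothetical completion of per_head_amax: abs, warp-reduce the max, and
// have lane 0 publish it into the shared per-head amax.
__device__ inline void per_head_amax_sketch(fx4& dst, float* amax) {
  dst = fx4_abs(dst);                    // matches the visible first line
  float m = warpReduceMax(fx4_max(dst)); // warp-wide abs-max of the chunk
  if ((threadIdx.x % 32) == 0) {
    atomicMaxNonNegF32(amax, m);         // accumulate across warps
  }
}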
@@ -2998,5 +2996,21 @@ std::tuple<at::Tensor, at::Tensor> dequantize_fp8_cache(
   throw std::runtime_error(
       "CUDA version is older than 12.0"); // requires CUDA>=12
 }
+at::Tensor quantize_qkv_per_head(
+    at::Tensor xqkv_amax_row, // [B_T, HH]
+    at::Tensor xqkv, // [B_T, HH, D_H]
+    at::Tensor varseq_seqpos, // [B_T]
+    std::optional<at::Tensor> varseq_batch, // [B_T]
+    at::Tensor q_seqstarts, // [B+1]
+    at::Tensor cache_K, // [B][MAX_T][N_KVH][D_H]
+    at::Tensor cache_V, // [B][MAX_T][N_KVH][D_H]
+    at::Tensor XQ_O, // [B_T][N_H][D]
+    int64_t max_seq_length, // Length of the sequence
+    std::optional<at::Tensor> qparam_k,
+    std::optional<at::Tensor> qparam_v) {
+  throw std::runtime_error(
+      "CUDA version is older than 12.0"); // requires CUDA>=12
+}
+
 #endif
 } // namespace fbgemm_gpu