
Commit acd35ed

Aya-ZIbra authored and facebook-github-bot committed
1 parent 6e9f1e0 commit acd35ed

1 file changed: +23 −9 lines changed

fbgemm_gpu/experimental/gen_ai/src/kv_cache/kv_cache.cu

@@ -227,7 +227,13 @@ DEVICE_INLINE void quantize_fp8_kv(
     __half2* qparam = nullptr,
     bool do_norm = false);
 
-DEVICE_INLINE void per_row_norm(fx4& dst);
+DEVICE_INLINE void per_row_norm(fx4& dst) {
+  float sum = fx4_dot(dst, dst);
+  // Warp reduce sum
+  sum = warpReduceSum(sum);
+  float rsqr = rsqrtf(sum / D_H);
+  dst = fx4_scale(dst, rsqr);
+}
 DEVICE_INLINE void per_row_amax(fx4& dst, float* amax);
 DEVICE_INLINE void per_head_amax(fx4& dst, float* amax);
 __global__ void nope_qkv_varseq_prefill_kernel(
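For context, the per_row_norm body now defined at the declaration site is an RMS-style row normalization over the head dimension: each lane's fx4 contributes its partial sum of squares via fx4_dot, the partials are combined across the warp, and the row is rescaled by rsqrtf(sum / D_H). The warpReduceSum helper is not part of this diff; a minimal shuffle-based sketch of such a warp-wide reduction (an illustrative assumption, not necessarily the helper used in kv_cache.cu) could look like:

// Hypothetical warp-wide sum: butterfly reduction over 32 lanes, so every
// lane ends up holding the full sum. Assumes all 32 lanes are active.
__device__ __forceinline__ float warp_reduce_sum_sketch(float val) {
  for (int offset = 16; offset > 0; offset >>= 1) {
    val += __shfl_xor_sync(0xffffffff, val, offset);
  }
  return val;
}

Because every lane ends up with the full sum, each lane can scale its own fx4 chunk locally and per_row_norm needs no shared memory.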
@@ -2850,7 +2856,6 @@ at::Tensor quantize_qkv_per_head(
     // HH += N_KVH_L * 2;
     qparam_k_ptr = qparam_k.value().data_ptr<float>();
     qparam_v_ptr = qparam_v.value().data_ptr<float>();
-    CHECK_EQ(HH, 7);
   }
   auto num_warps = B_T * HH;
   dim3 block_size(kThreadsPerWarp, kWarpsPerBlock);
@@ -2883,13 +2888,6 @@ at::Tensor quantize_qkv_per_head(
   C10_CUDA_KERNEL_LAUNCH_CHECK();
   return scale_q;
 }
-DEVICE_INLINE void per_row_norm(fx4& dst) {
-  float sum = fx4_dot(dst, dst);
-  // Warp reduce sum
-  sum = warpReduceSum(sum);
-  float rsqr = rsqrtf(sum / D_H);
-  dst = fx4_scale(dst, rsqr);
-}
 
 DEVICE_INLINE void per_head_amax(fx4& dst, float* amax) {
   dst = fx4_abs(dst);
@@ -2998,5 +2996,21 @@ std::tuple<at::Tensor, at::Tensor> dequantize_fp8_cache(
   throw std::runtime_error(
       "CUDA version is older than 12.0"); // requires CUDA>=12
 }
+at::Tensor quantize_qkv_per_head(
+    at::Tensor xqkv_amax_row, // [B_T, HH]
+    at::Tensor xqkv, // [B_T, HH, D_H]
+    at::Tensor varseq_seqpos, // [B_T]
+    std::optional<at::Tensor> varseq_batch, // [B_T]
+    at::Tensor q_seqstarts, // [B+1]
+    at::Tensor cache_K, // [B][MAX_T][N_KVH][D_H]
+    at::Tensor cache_V, // [B][MAX_T][N_KVH][D_H]
+    at::Tensor XQ_O, // [B_T][N_H][D]
+    int64_t max_seq_length, // Length of the sequence
+    std::optional<at::Tensor> qparam_k,
+    std::optional<at::Tensor> qparam_v) {
+  throw std::runtime_error(
+      "CUDA version is older than 12.0"); // requires CUDA>=12
+}
+
 #endif
 } // namespace fbgemm_gpu
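The new quantize_qkv_per_head stub mirrors the existing pre-CUDA-12 fallback for dequantize_fp8_cache: the symbol still compiles and links when the toolkit is older than 12.0, and the failure surfaces as a runtime error rather than a build break. A caller built against mixed toolkits could also gate the call at compile time; the sketch below is illustrative only (the argument tensors are assumed to be prepared elsewhere, and CUDA_VERSION comes from cuda.h):

#if defined(CUDA_VERSION) && (CUDA_VERSION >= 12000)
  // Kernel-backed path is available on CUDA >= 12.0.
  auto scale_q = quantize_qkv_per_head(
      xqkv_amax_row, xqkv, varseq_seqpos, varseq_batch, q_seqstarts,
      cache_K, cache_V, XQ_O, max_seq_length, qparam_k, qparam_v);
#else
  // On older toolkits the stub above would throw at runtime anyway.
  TORCH_CHECK(false, "quantize_qkv_per_head requires CUDA >= 12.0");
#endif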
