Skip to content

Commit 7aa1fe5

Browse files
authored
Add test case for indexer_k_quant_and_cache (vllm-project#201)
1 parent 14631a4 commit 7aa1fe5

2 files changed

Lines changed: 9 additions & 9 deletions

File tree

csrc/cache.cpp

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -425,11 +425,10 @@ class indexer_k_quant_and_cache_kernel {
425425

426426
// Compute local amax
427427
float amax = 0.f;
428-
float k_vals[VEC_SIZE];
428+
scalar_t k_vals[VEC_SIZE];
429429
for (int i = 0; i < VEC_SIZE; i++) {
430-
k_vals[i] =
431-
static_cast<float>(k_[token_idx * head_dim_ + head_dim_idx + i]);
432-
amax = sycl::fmax(amax, sycl::fabs(k_vals[i]));
430+
k_vals[i] = k_[token_idx * head_dim_ + head_dim_idx + i];
431+
amax = sycl::fmax(amax, sycl::fabs(static_cast<float>(k_vals[i])));
433432
}
434433

435434
// group-level reduction (sub-group reduce max)

tests/test_indexer_k_quant_and_cache.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,9 @@
1313
QUANT_BLOCK_SIZES = [128]
1414
BLOCK_SIZES = [16]
1515
SCALE_FMTS = ["ue8m0", "fp8e4m3"]
16-
# TODO: will add back torch.bfloat16, torch.float16
16+
# TODO: will add back torch.float16
1717
# after fp8_e4m3 acc is verified
18-
DTYPES = [torch.float32]
18+
DTYPES = [torch.float32, torch.bfloat16]
1919

2020
# override pytest parameters when enable mini pytest
2121
MINI_PYTEST_PARAMS = {
@@ -57,11 +57,12 @@ def _pytorch_group_quant(
5757
original_shape = x.shape
5858
num_groups = original_shape[-1] // group_size
5959
group_shape = original_shape[:-1] + (num_groups, group_size)
60-
x_grouped = x.view(group_shape)
6160

61+
# Quantization should be done in FP32 for better accuracy.
62+
x_grouped = x.view(group_shape).float()
6263
abs_max = torch.amax(torch.abs(x_grouped), dim=-1, keepdim=False)
63-
abs_max = torch.maximum(abs_max,
64-
torch.tensor(eps, device=x.device, dtype=x.dtype))
64+
abs_max = torch.maximum(
65+
abs_max, torch.tensor(eps, device=x.device, dtype=torch.float32))
6566

6667
FP8_MAX = torch.finfo(dtype).max
6768
FP8_MIN = torch.finfo(dtype).min

0 commit comments

Comments (0)