Skip to content

Commit e7dee22

Browse files
authored
[Quant] update fp8 quant kernel (vllm-project#147)
* update fp8 quant kernel Signed-off-by: Zhu, Zufang <zufang.zhu@intel.com> * ensure vectorization appliable Signed-off-by: Zhu, Zufang <zufang.zhu@intel.com> * add ut for static quant fp8 Signed-off-by: Zhu, Zufang <zufang.zhu@intel.com> * remove useless val Signed-off-by: Zhu, Zufang <zufang.zhu@intel.com> * remove wrong comments Signed-off-by: Zhu, Zufang <zufang.zhu@intel.com> * update ut Signed-off-by: Zhu, Zufang <zufang.zhu@intel.com> * thanks copilot Signed-off-by: Zhu, Zufang <zufang.zhu@intel.com> --------- Signed-off-by: Zhu, Zufang <zufang.zhu@intel.com>
1 parent 0ff40c6 commit e7dee22

9 files changed

Lines changed: 603 additions & 84 deletions

File tree

csrc/dispatch_utils.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,3 +78,12 @@
7878
#define VLLM_DISPATCH_INTEGRAL_AND_UNSIGNED_TYPES(TYPE, NAME, ...) \
7979
AT_DISPATCH_SWITCH( \
8080
TYPE, NAME, VLLM_DISPATCH_CASE_INTEGRAL_AND_UNSIGNED_TYPES(__VA_ARGS__))
81+
82+
#define VLLM_DISPATCH_BOOL(expr, const_expr, ...) \
83+
if (expr) { \
84+
constexpr bool const_expr = true; \
85+
__VA_ARGS__(); \
86+
} else { \
87+
constexpr bool const_expr = false; \
88+
__VA_ARGS__(); \
89+
}

csrc/ops.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,10 @@ void gather_cache(
7575
std::optional<torch::Tensor> seq_starts = std::nullopt);
7676

7777
void static_scaled_fp8_quant(
78-
torch::Tensor& out, torch::Tensor const& input, torch::Tensor const& scale);
78+
torch::Tensor& out,
79+
torch::Tensor const& input,
80+
torch::Tensor const& scale,
81+
std::optional<std::tuple<int64_t, int64_t>> group_shape = std::nullopt);
7982

8083
void dynamic_scaled_fp8_quant(
8184
torch::Tensor& out, torch::Tensor const& input, torch::Tensor& scale);

0 commit comments

Comments
 (0)