Skip to content

Commit dd270fc

Browse files
momochen authored and facebook-github-bot committed
Update MSLK Triton FP8 row quantization kernel to match CUDA arithmetic and delete the C++ quantize_fp8_per_row kernel (#224)
Summary: The Triton kernel _kernel_quantize_fp8_row used different arithmetic than the CUDA quantize_fp8_per_row — reciprocal-scale-multiply (a * (MAX_FP8 / amax)) vs true-division (a / (amax / MAX_FP8)). This caused bitwise discrepancies at FP8 rounding boundaries. This diff converges the Triton kernel to match CUDA's arithmetic exactly, then removes the now-redundant C++ kernel. A discrepancy test (fp8_quantize_discrepancy_test.py) was run before being removed; it showed 100% bitwise parity between the CUDA and Triton implementations. Differential Revision: D96502922
1 parent 133d5fc commit dd270fc

File tree

7 files changed

+15
-596
lines changed

7 files changed

+15
-596
lines changed

ci/scripts/mslk_build.bash

Lines changed: 0 additions & 1 deletion
Original file line number · Diff line number · Diff line change
@@ -597,7 +597,6 @@ __verify_library_symbols () {
597597
mslk::gemm::f8f8bf16_rowwise
598598
mslk::kv_cache::rope_qkv_decoding
599599
mslk::moe::index_shuffling_torch
600-
mslk::quantize::quantize_fp8_per_row
601600
)
602601
fi
603602

csrc/quantize/quantize.cu

Lines changed: 0 additions & 178 deletions
Original file line number · Diff line number · Diff line change
@@ -884,184 +884,6 @@ __global__ void computeFP8QuantizeScaleRowwise(
884884
}
885885
}
886886

887-
template <typename SCALE, typename T_OUT, typename T_S, typename T_IN>
888-
void invokeComputeScalesAndQuantizeMatrix(
889-
T_OUT* output,
890-
T_S* quant_ptr,
891-
const T_IN* input,
892-
const int64_t numel,
893-
const int64_t lda,
894-
const float* scale_ub,
895-
bool stochastic_rounding,
896-
const c10::cuda::CUDAStream stream) {
897-
dim3 grid(numel / lda);
898-
#ifdef USE_ROCM
899-
bool use_shmem = true;
900-
#else
901-
bool use_shmem = false;
902-
#endif
903-
auto const shmem_size = lda * sizeof(T_IN);
904-
if (shmem_size >= (48 << 10)) {
905-
cudaError_t ret;
906-
#ifndef USE_ROCM
907-
if (stochastic_rounding) {
908-
ret = cudaFuncSetAttribute(
909-
dynamicQuantizeMatrixRowwiseStoc<SCALE, T_OUT, T_S, T_IN>,
910-
cudaFuncAttributeMaxDynamicSharedMemorySize,
911-
shmem_size);
912-
} else {
913-
ret = cudaFuncSetAttribute(
914-
dynamicQuantizeMatrixRowwise<SCALE, T_OUT, T_S, T_IN>,
915-
cudaFuncAttributeMaxDynamicSharedMemorySize,
916-
shmem_size);
917-
}
918-
use_shmem = ret == cudaSuccess;
919-
#else
920-
use_shmem = false;
921-
#endif
922-
}
923-
if (use_shmem) {
924-
dim3 block(std::min((lda + 31) / 32 * 32, static_cast<int64_t>(1024)));
925-
926-
if (stochastic_rounding) {
927-
at::PhiloxCudaState rng_engine_inputs;
928-
auto gen = at::cuda::detail::getDefaultCUDAGenerator();
929-
std::lock_guard<std::mutex> lock(gen.mutex());
930-
rng_engine_inputs =
931-
at::check_generator<at::CUDAGeneratorImpl>(gen)->philox_cuda_state(4);
932-
933-
MSLK_LAUNCH_KERNEL(
934-
(dynamicQuantizeMatrixRowwiseStoc<SCALE, T_OUT, T_S, T_IN>),
935-
grid,
936-
block,
937-
shmem_size,
938-
stream,
939-
output,
940-
quant_ptr,
941-
input,
942-
numel,
943-
lda,
944-
scale_ub,
945-
rng_engine_inputs);
946-
} else {
947-
MSLK_LAUNCH_KERNEL(
948-
(dynamicQuantizeMatrixRowwise<SCALE, T_OUT, T_S, T_IN>),
949-
grid,
950-
block,
951-
shmem_size,
952-
stream,
953-
output,
954-
quant_ptr,
955-
input,
956-
numel,
957-
lda,
958-
scale_ub);
959-
}
960-
} else {
961-
dim3 block(CTA_SIZE);
962-
MSLK_LAUNCH_KERNEL(
963-
(computeFP8QuantizeScaleRowwise<SCALE, T_S, T_IN>),
964-
grid,
965-
block,
966-
0,
967-
stream,
968-
quant_ptr,
969-
input,
970-
numel,
971-
lda,
972-
scale_ub);
973-
invokeQuantizeMatrixRowwise(
974-
output, quant_ptr, input, numel, lda, stochastic_rounding, stream);
975-
}
976-
}
977-
978-
std::vector<at::Tensor> quantize_fp8_per_row(
979-
at::Tensor input,
980-
std::optional<at::Tensor> bs, // batch size
981-
std::optional<at::Tensor> scale_ub, // scale upperbound
982-
std::optional<c10::ScalarType> output_dtype, // Quantization type
983-
bool stochastic_rounding) {
984-
TORCH_CHECK(
985-
input.dim() >= 2,
986-
"Invalid dim. The dim of input should be greater than or equal to 2");
987-
TORCH_CHECK(
988-
input.scalar_type() == torch::kBFloat16 ||
989-
input.scalar_type() == torch::kFloat ||
990-
input.scalar_type() == torch::kHalf,
991-
"Invalid datatype. input must be BF16, FP16 or FP32");
992-
TORCH_CHECK(
993-
!stochastic_rounding || input.size(-1) % 4 == 0,
994-
"input row dim must be 4's multiple when stochastic_rounding is True");
995-
// Default data type is f8_e4m3fn.
996-
c10::ScalarType quantization_type = torch_fp8_e4m3;
997-
if (output_dtype.has_value()) {
998-
TORCH_CHECK(
999-
(output_dtype.value() == torch_fp8_e4m3 ||
1000-
output_dtype.value() == torch_fp8_e5m2),
1001-
"Invalid output type, must be e4m3 or e5m2.");
1002-
quantization_type = output_dtype.value();
1003-
}
1004-
std::vector<long int> quantized_input_shape;
1005-
for (int i = 0; i < input.dim(); i++)
1006-
quantized_input_shape.push_back(input.size(i));
1007-
std::vector<int64_t> scale_shape;
1008-
for (int i = 0; i < input.dim() - 1; i++)
1009-
scale_shape.push_back(input.size(i));
1010-
1011-
input = input.cuda();
1012-
at::Tensor quantized_input = torch::empty(
1013-
quantized_input_shape,
1014-
torch::dtype(quantization_type)
1015-
.device(torch::kCUDA, at::cuda::current_device())
1016-
.requires_grad(false));
1017-
at::Tensor scales = torch::empty(
1018-
scale_shape,
1019-
torch::dtype(torch::kFloat32)
1020-
.device(torch::kCUDA, at::cuda::current_device())
1021-
.requires_grad(false));
1022-
1023-
if (input.numel() == 0) {
1024-
return std::vector<at::Tensor>{quantized_input, scales};
1025-
}
1026-
1027-
// Templatize implementation based on output type.
1028-
if (quantization_type == torch_fp8_e4m3) {
1029-
auto* const quantized_input_ptr =
1030-
reinterpret_cast<__nv_fp8_e4m3*>(quantized_input.data_ptr());
1031-
const auto stream = at::cuda::getCurrentCUDAStream();
1032-
invokeComputeScalesAndQuantizeMatrix<FP8_E4M3_MAX>(
1033-
quantized_input_ptr,
1034-
reinterpret_cast<float*>(scales.data_ptr()),
1035-
reinterpret_cast<const __nv_bfloat16*>(input.data_ptr()),
1036-
input.numel(),
1037-
input.size(-1),
1038-
scale_ub.has_value()
1039-
? reinterpret_cast<float*>(scale_ub.value().data_ptr())
1040-
: nullptr,
1041-
stochastic_rounding,
1042-
stream);
1043-
1044-
return std::vector<at::Tensor>{quantized_input, scales};
1045-
} else {
1046-
auto* const quantized_input_ptr =
1047-
reinterpret_cast<__nv_fp8_e5m2*>(quantized_input.data_ptr());
1048-
const auto stream = at::cuda::getCurrentCUDAStream();
1049-
invokeComputeScalesAndQuantizeMatrix<FP8_E5M2_MAX>(
1050-
quantized_input_ptr,
1051-
reinterpret_cast<float*>(scales.data_ptr()),
1052-
reinterpret_cast<const __nv_bfloat16*>(input.data_ptr()),
1053-
input.numel(),
1054-
input.size(-1),
1055-
scale_ub.has_value()
1056-
? reinterpret_cast<float*>(scale_ub.value().data_ptr())
1057-
: nullptr,
1058-
stochastic_rounding,
1059-
stream);
1060-
1061-
return std::vector<at::Tensor>{quantized_input, scales};
1062-
}
1063-
}
1064-
1065887
#if defined(CUDA_VERSION) && (CUDA_VERSION >= 12080)
1066888

1067889
#ifdef __CUDA_ARCH__

csrc/quantize/quantize_ops.cpp

Lines changed: 0 additions & 26 deletions
Original file line number · Diff line number · Diff line change
@@ -23,9 +23,6 @@ TORCH_LIBRARY_FRAGMENT(mslk, m) {
2323

2424
m.def(
2525
"quantize_fp8_per_tensor(Tensor input, Tensor? bs=None, Tensor? scale_ub=None, bool stochastic_rounding=False) -> Tensor[]");
26-
m.def(
27-
"quantize_fp8_per_row(Tensor input, Tensor? bs=None, Tensor? scale_ub=None, ScalarType? output_dtype=None, bool stochastic_rounding = False) -> Tensor[] ");
28-
2926
m.def(
3027
"get_fp8_per_tensor_scale(Tensor input, Tensor? bs=None, Tensor? scale_ub=None) -> Tensor");
3128

@@ -41,7 +38,6 @@ TORCH_LIBRARY_FRAGMENT(mslk, m) {
4138

4239
TORCH_LIBRARY_IMPL(mslk, CUDA, m) {
4340
DISPATCH_TO_CUDA("quantize_fp8_per_tensor", quantize_fp8_per_tensor);
44-
DISPATCH_TO_CUDA("quantize_fp8_per_row", quantize_fp8_per_row);
4541
DISPATCH_TO_CUDA("per_tensor_quantize_i8", per_tensor_quantize_i8);
4642
DISPATCH_TO_CUDA(
4743
"per_tensor_dynamic_quantize_i8", per_tensor_dynamic_quantize_i8);
@@ -74,30 +70,8 @@ std::vector<at::Tensor> quantize_fp8_per_tensor_meta(
7470
return {Y, scale};
7571
}
7672

77-
std::vector<at::Tensor> quantize_fp8_per_row_meta(
78-
at::Tensor input,
79-
std::optional<at::Tensor> /* bs */,
80-
std::optional<at::Tensor> /* scale_ub */,
81-
std::optional<c10::ScalarType> /* output_dtype */,
82-
bool /* stochastic_rounding */) {
83-
int dims = input.dim();
84-
TORCH_CHECK(dims == 2 || dims == 3, "The dim of input should be 2 or 3");
85-
at::Tensor Y = at::empty_like(input, input.options().dtype(torch_fp8_e4m3));
86-
at::Tensor scale;
87-
if (dims == 2) {
88-
const at::SymInt M = input.sym_size(0);
89-
scale = at::empty_symint({M}, input.options().dtype(at::kFloat));
90-
} else {
91-
const at::SymInt B = input.sym_size(0);
92-
const at::SymInt M = input.sym_size(1);
93-
scale = at::empty_symint({B, M}, input.options().dtype(at::kFloat));
94-
}
95-
return {Y, scale};
96-
}
97-
9873
TORCH_LIBRARY_IMPL(mslk, Meta, m) {
9974
DISPATCH_TO_META("quantize_fp8_per_tensor", quantize_fp8_per_tensor_meta);
100-
DISPATCH_TO_META("quantize_fp8_per_row", quantize_fp8_per_row_meta);
10175
}
10276

10377
} // namespace mslk::quantize

include/mslk/quantize/quantize.h

Lines changed: 0 additions & 7 deletions
Original file line number · Diff line number · Diff line change
@@ -21,13 +21,6 @@ std::vector<at::Tensor> quantize_fp8_per_tensor(
2121
std::optional<at::Tensor> scale_ub, // scale upperbound
2222
const bool stochastic_rounding); // whether apply stochastic rounding
2323

24-
std::vector<at::Tensor> quantize_fp8_per_row(
25-
at::Tensor input,
26-
std::optional<at::Tensor> bs, // batch size
27-
std::optional<at::Tensor> scale_ub, // scale upperbound
28-
std::optional<c10::ScalarType> output_dtype, // output dtype
29-
bool stochastic_rounding); // whether apply stochastic rounding
30-
3124
at::Tensor quantize_fp8_per_tensor_fixed_scale(
3225
at::Tensor input,
3326
at::Tensor scale,

mslk/quantize/triton/fp8_quantize.py

Lines changed: 14 additions & 14 deletions
Original file line number · Diff line number · Diff line change
@@ -155,11 +155,11 @@ def _kernel_quantize_fp8_row(
155155
else:
156156
cur_max = tl.maximum(cur_max, EPS)
157157

158-
# Scale and quantize.
159-
a_scale = MAX_FP8 / cur_max
160-
tl.store(A_scale + pid, 1.0 / a_scale)
158+
# Scale and quantize
159+
scale = tl.div_rn(cur_max, MAX_FP8)
160+
tl.store(A_scale + pid, scale)
161161

162-
a_fp8 = a * a_scale
162+
a_fp8 = tl.div_rn(a.to(tl.float32), scale)
163163
a_fp8 = tl.clamp(a_fp8, -MAX_FP8, MAX_FP8).to(TL_FP8_DTYPE)
164164
tl.store(
165165
A_fp8 + a_fp8_offset_base + n_offset * stride_ok,
@@ -188,9 +188,9 @@ def _kernel_quantize_fp8_row(
188188
else:
189189
cur_max = tl.maximum(cur_max, EPS)
190190

191-
# Scale and quantize.
192-
a_scale = MAX_FP8 / cur_max
193-
tl.store(A_scale + pid, 1.0 / a_scale)
191+
# Scale and quantize
192+
scale = tl.div_rn(cur_max, MAX_FP8)
193+
tl.store(A_scale + pid, scale)
194194

195195
# Write quantized values for the first K elements (from A), and pad the rest with zeros up to K_fp8
196196
n_offset = tl.arange(0, BLOCK_SIZE)
@@ -202,7 +202,7 @@ def _kernel_quantize_fp8_row(
202202
other=0.0,
203203
)
204204
# For elements >= K, a will be 0
205-
a_fp8 = a * a_scale
205+
a_fp8 = tl.div_rn(a.to(tl.float32), scale)
206206
# Clamp A to fp8 range to make sure there's no overflow.
207207
# This is required for AMD. Nvidia's default saturation
208208
# handles it, but it's nice to have anyway.
@@ -222,7 +222,7 @@ def triton_quantize_fp8_row(
222222
scale_ub: Optional[torch.Tensor] = None,
223223
zero_start_index_M: Optional[torch.Tensor] = None,
224224
align_rows_to: Optional[int] = None,
225-
eps_opt: Optional[float] = None,
225+
eps_opt: Optional[float] = 1.0 / 512.0,
226226
) -> tuple[torch.Tensor, torch.Tensor]:
227227
"""
228228
Call the triton quantize fp8 row kernel to quantize a tensor to fp8 with row-wise scalings.
@@ -232,10 +232,10 @@ def triton_quantize_fp8_row(
232232
scale_ub (Tensor): Maximum allowed value for scale.
233233
zero_start_index_M (Tensor): Indicates number of nonzero elements in each row.
234234
align_rows_to: Pad rows to align to this value. Useful for downstream kernels accepting specific sizes (e.g., multiple of 16)
235-
eps_opt: Lower bound for amax. If provided, amax will be clamped to this value.
235+
eps_opt: Lower bound for amax (default 1/512 to match CUDA min_scaling_factor).
236236
Returns:
237237
torch.Tensor: fp8 scaled tensor.
238-
torch.Tensor: reciprocal scale tensor per row.
238+
torch.Tensor: scale tensor per row (scale = amax / MAX_FP8).
239239
"""
240240
if scale_ub is not None and scale_ub.device != a.device:
241241
raise Exception("'scale_ub' must be on the same device as 'a'")
@@ -693,7 +693,7 @@ def quantize_fp8_row(
693693
use_triton: bool = True,
694694
output_device: Optional[torch.device] = None,
695695
align_rows_to: Optional[int] = None,
696-
eps_opt: Optional[float] = None,
696+
eps_opt: Optional[float] = 1.0 / 512.0,
697697
) -> tuple[torch.Tensor, torch.Tensor]:
698698
"""
699699
Quantize a to fp8 with row-wise scalings and optionally move to output device.
@@ -705,10 +705,10 @@ def quantize_fp8_row(
705705
use_triton (bool): Whether to use triton kernel or pytorch.
706706
output_device (torch.device): Device to optionally move the scaled tensors to.
707707
align_rows_to: Pad rows to align to this value. Useful for downstream kernels accepting specific sizes (e.g., multiple of 16)
708-
eps_opt: Lower bound for amax. If amax is below this value, it will be clamped to this value.
708+
eps_opt: Lower bound for amax (default 1/512 to match CUDA min_scaling_factor).
709709
Returns:
710710
torch.Tensor: fp8 scaled tensor.
711-
torch.Tensor: The reciprocal scale tensor per row.
711+
torch.Tensor: scale tensor per row (scale = amax / MAX_FP8).
712712
"""
713713

714714
if a.device == torch.device("cpu"):

test/gemm/gemm_test.py

Lines changed: 1 addition & 1 deletion
Original file line number · Diff line number · Diff line change
@@ -553,7 +553,7 @@ def f(
553553
def f(
554554
x: torch.Tensor, w: torch.Tensor, bias: Optional[torch.Tensor]
555555
) -> torch.Tensor:
556-
xq, x_scale = torch.ops.mslk.quantize_fp8_per_row(x, output_dtype=QType)
556+
xq, x_scale = quantize_fp8_row(x)
557557
wq, w_scale = quantize_fp8_row(w)
558558
if UseTriton and torch.version.cuda:
559559
zq = matmul_fp8_row(xq, wq, x_scale, w_scale)

0 commit comments

Comments (0)