@@ -884,184 +884,6 @@ __global__ void computeFP8QuantizeScaleRowwise(
884884 }
885885}
886886
// Computes one quantization scale per row of `input` and quantizes the whole
// matrix into `output` (FP8). `quant_ptr` receives numel / lda row scales.
//
// Strategy selection:
//   * Fused path: one block per row stages the row in dynamic shared memory
//     and does scale + quantize in a single kernel. Rows needing more than
//     the default 48 KiB of dynamic shared memory require an explicit opt-in
//     via cudaFuncSetAttribute (CUDA only; on ROCm such rows always fall
//     back).
//   * Fallback path: computeFP8QuantizeScaleRowwise computes the scales,
//     then invokeQuantizeMatrixRowwise quantizes in a second pass.
//
// Template parameters:
//   SCALE - tag type selecting the FP8 max value (e.g. FP8_E4M3_MAX).
//   T_OUT - FP8 output element type.
//   T_S   - scale element type.
//   T_IN  - input element type.
//
// Arguments:
//   output              - device buffer for numel quantized elements.
//   quant_ptr           - device buffer for numel / lda row scales.
//   input               - device buffer of numel elements, dense row-major
//                         with row length lda (numel % lda == 0 assumed).
//   scale_ub            - optional device pointer to a scale upper bound;
//                         may be nullptr.
//   stochastic_rounding - select the stochastic-rounding kernels.
//   stream              - CUDA stream every kernel is launched on.
template <typename SCALE, typename T_OUT, typename T_S, typename T_IN>
void invokeComputeScalesAndQuantizeMatrix(
    T_OUT* output,
    T_S* quant_ptr,
    const T_IN* input,
    const int64_t numel,
    const int64_t lda,
    const float* scale_ub,
    bool stochastic_rounding,
    const c10::cuda::CUDAStream stream) {
  // One block per row.
  dim3 grid(numel / lda);
#ifdef USE_ROCM
  bool use_shmem = true;
#else
  bool use_shmem = false;
#endif
  auto const shmem_size = lda * sizeof(T_IN);
  if (shmem_size >= (48 << 10)) {
#ifndef USE_ROCM
    // Rows above the default 48 KiB dynamic shared-memory limit need an
    // explicit per-kernel opt-in; if it is rejected, use the fallback path.
    cudaError_t ret;
    if (stochastic_rounding) {
      ret = cudaFuncSetAttribute(
          dynamicQuantizeMatrixRowwiseStoc<SCALE, T_OUT, T_S, T_IN>,
          cudaFuncAttributeMaxDynamicSharedMemorySize,
          shmem_size);
    } else {
      ret = cudaFuncSetAttribute(
          dynamicQuantizeMatrixRowwise<SCALE, T_OUT, T_S, T_IN>,
          cudaFuncAttributeMaxDynamicSharedMemorySize,
          shmem_size);
    }
    use_shmem = ret == cudaSuccess;
    if (!use_shmem) {
      // BUGFIX: a failed cudaFuncSetAttribute leaves its error code in the
      // per-thread error state; clear it so it is not misattributed to a
      // later CUDA call or launch check.
      (void)cudaGetLastError();
    }
#else
    use_shmem = false;
#endif
  }
  if (use_shmem) {
    // Round the row length up to a full warp, capped at the 1024-thread
    // block limit.
    dim3 block(std::min((lda + 31) / 32 * 32, static_cast<int64_t>(1024)));

    if (stochastic_rounding) {
      at::PhiloxCudaState rng_engine_inputs;
      auto gen = at::cuda::detail::getDefaultCUDAGenerator();
      {
        // Lock only while reserving Philox offsets; the launch itself does
        // not need the generator mutex.
        std::lock_guard<std::mutex> lock(gen.mutex());
        rng_engine_inputs =
            at::check_generator<at::CUDAGeneratorImpl>(gen)->philox_cuda_state(
                4);
      }

      MSLK_LAUNCH_KERNEL(
          (dynamicQuantizeMatrixRowwiseStoc<SCALE, T_OUT, T_S, T_IN>),
          grid,
          block,
          shmem_size,
          stream,
          output,
          quant_ptr,
          input,
          numel,
          lda,
          scale_ub,
          rng_engine_inputs);
    } else {
      MSLK_LAUNCH_KERNEL(
          (dynamicQuantizeMatrixRowwise<SCALE, T_OUT, T_S, T_IN>),
          grid,
          block,
          shmem_size,
          stream,
          output,
          quant_ptr,
          input,
          numel,
          lda,
          scale_ub);
    }
  } else {
    // Two-pass fallback: compute the per-row scales, then quantize.
    dim3 block(CTA_SIZE);
    MSLK_LAUNCH_KERNEL(
        (computeFP8QuantizeScaleRowwise<SCALE, T_S, T_IN>),
        grid,
        block,
        0,
        stream,
        quant_ptr,
        input,
        numel,
        lda,
        scale_ub);
    invokeQuantizeMatrixRowwise(
        output, quant_ptr, input, numel, lda, stochastic_rounding, stream);
  }
}
977-
namespace {

// Dispatches invokeComputeScalesAndQuantizeMatrix on the runtime input dtype
// so each element is read through a matching CUDA type. (Previously FP16 and
// FP32 inputs were unconditionally reinterpreted as BF16, silently corrupting
// the quantization even though those dtypes passed validation.)
template <typename SCALE, typename T_OUT>
void launchQuantizeFP8PerRow(
    const at::Tensor& input,
    at::Tensor& quantized_input,
    at::Tensor& scales,
    const std::optional<at::Tensor>& scale_ub,
    bool stochastic_rounding) {
  const auto stream = at::cuda::getCurrentCUDAStream();
  auto* const out_ptr = reinterpret_cast<T_OUT*>(quantized_input.data_ptr());
  auto* const scales_ptr = reinterpret_cast<float*>(scales.data_ptr());
  // NOTE(review): assumes scale_ub, when provided, is a float32 CUDA tensor
  // — confirm against callers.
  const float* const ub_ptr = scale_ub.has_value()
      ? reinterpret_cast<const float*>(scale_ub.value().data_ptr())
      : nullptr;
  const int64_t numel = input.numel();
  const int64_t lda = input.size(-1);
  switch (input.scalar_type()) {
    case torch::kBFloat16:
      invokeComputeScalesAndQuantizeMatrix<SCALE>(
          out_ptr,
          scales_ptr,
          reinterpret_cast<const __nv_bfloat16*>(input.data_ptr()),
          numel,
          lda,
          ub_ptr,
          stochastic_rounding,
          stream);
      break;
    case torch::kHalf:
      invokeComputeScalesAndQuantizeMatrix<SCALE>(
          out_ptr,
          scales_ptr,
          reinterpret_cast<const __half*>(input.data_ptr()),
          numel,
          lda,
          ub_ptr,
          stochastic_rounding,
          stream);
      break;
    default:
      // Validation in quantize_fp8_per_row guarantees FP32 here.
      invokeComputeScalesAndQuantizeMatrix<SCALE>(
          out_ptr,
          scales_ptr,
          reinterpret_cast<const float*>(input.data_ptr()),
          numel,
          lda,
          ub_ptr,
          stochastic_rounding,
          stream);
      break;
  }
}

} // namespace

// Rowwise FP8 quantization. Returns {quantized_input, scales}:
// `quantized_input` has input's shape in the requested FP8 dtype and
// `scales` holds one float per row (input's shape minus the last dim).
//
// Arguments:
//   input               - BF16/FP16/FP32 tensor with dim >= 2; moved to CUDA
//                         and made contiguous as needed.
//   bs                  - batch size (unused in this path; kept for
//                         interface compatibility).
//   scale_ub            - optional upper bound for the computed scales.
//   output_dtype        - e4m3fn (default) or e5m2.
//   stochastic_rounding - enable stochastic rounding (requires the last dim
//                         to be a multiple of 4).
std::vector<at::Tensor> quantize_fp8_per_row(
    at::Tensor input,
    std::optional<at::Tensor> bs, // batch size
    std::optional<at::Tensor> scale_ub, // scale upperbound
    std::optional<c10::ScalarType> output_dtype, // Quantization type
    bool stochastic_rounding) {
  TORCH_CHECK(
      input.dim() >= 2,
      "Invalid dim. The dim of input should be greater than or equal to 2");
  TORCH_CHECK(
      input.scalar_type() == torch::kBFloat16 ||
          input.scalar_type() == torch::kFloat ||
          input.scalar_type() == torch::kHalf,
      "Invalid datatype. input must be BF16, FP16 or FP32");
  TORCH_CHECK(
      !stochastic_rounding || input.size(-1) % 4 == 0,
      "input row dim must be 4's multiple when stochastic_rounding is True");
  // Default data type is f8_e4m3fn.
  c10::ScalarType quantization_type = torch_fp8_e4m3;
  if (output_dtype.has_value()) {
    TORCH_CHECK(
        (output_dtype.value() == torch_fp8_e4m3 ||
         output_dtype.value() == torch_fp8_e5m2),
        "Invalid output type, must be e4m3 or e5m2.");
    quantization_type = output_dtype.value();
  }
  // Output keeps the full input shape; scales drop the last (row) dim.
  const std::vector<int64_t> quantized_input_shape = input.sizes().vec();
  std::vector<int64_t> scale_shape(
      quantized_input_shape.begin(), quantized_input_shape.end() - 1);

  // The kernels index the input as a dense row-major matrix through a raw
  // pointer, so the tensor must live on CUDA and be contiguous.
  input = input.cuda().contiguous();
  at::Tensor quantized_input = torch::empty(
      quantized_input_shape,
      torch::dtype(quantization_type)
          .device(torch::kCUDA, at::cuda::current_device())
          .requires_grad(false));
  at::Tensor scales = torch::empty(
      scale_shape,
      torch::dtype(torch::kFloat32)
          .device(torch::kCUDA, at::cuda::current_device())
          .requires_grad(false));

  // Nothing to quantize; return the (empty) outputs without launching.
  if (input.numel() == 0) {
    return std::vector<at::Tensor>{quantized_input, scales};
  }

  // Templatize implementation based on output type.
  if (quantization_type == torch_fp8_e4m3) {
    launchQuantizeFP8PerRow<FP8_E4M3_MAX, __nv_fp8_e4m3>(
        input, quantized_input, scales, scale_ub, stochastic_rounding);
  } else {
    launchQuantizeFP8PerRow<FP8_E5M2_MAX, __nv_fp8_e5m2>(
        input, quantized_input, scales, scale_ub, stochastic_rounding);
  }
  return std::vector<at::Tensor>{quantized_input, scales};
}
1064-
1065887#if defined(CUDA_VERSION) && (CUDA_VERSION >= 12080)
1066888
1067889#ifdef __CUDA_ARCH__
0 commit comments