@@ -823,58 +823,10 @@ int CutlassFpAIntBGemmRunner<T, WeightType>::getWorkspaceSize(const int m,
   return max_grid_m * max_grid_n * split_k_limit * 4;
 }
 
-// =============================== Specialization T == WeightType
-// =======================================
-template <typename WeightType>
-void CutlassFpAIntBGemmRunner<float, WeightType>::gemm_bias_act(
-    const float* A,
-    const WeightType* B,
-    const float* weight_scales,
-    const float* biases,
-    float* C,
-    int m,
-    int n,
-    int k,
-    int group_size,
-    std::string activation_type,
-    char* workspace_ptr,
-    const size_t workspace_bytes,
-    cudaStream_t stream) {
-  throw std::runtime_error(
-      ("Attempting to run mixed gemm bias act when the types are the same is "
-       "an error."));
-}
-
-template <typename WeightType>
-void CutlassFpAIntBGemmRunner<float, WeightType>::gemm(
-    const float* A,
-    const WeightType* B,
-    const float* weight_scales,
-    float* C,
-    int m,
-    int n,
-    int k,
-    int group_size,
-    char* workspace_ptr,
-    const size_t workspace_bytes,
-    cudaStream_t stream) {
-  throw std::runtime_error((
-      "Attempting to run mixed gemm when the types are the same is an error."));
-}
-
-template <typename WeightType>
-int CutlassFpAIntBGemmRunner<float, WeightType>::getWorkspaceSize(const int m,
-                                                                  const int n,
-                                                                  const int k) {
-  return 0;
-}
-
-template class CutlassFpAIntBGemmRunner<float, uint8_t>;
 template class CutlassFpAIntBGemmRunner<half, uint8_t>;
 #ifdef PADDLE_CUDA_BF16
 template class CutlassFpAIntBGemmRunner<__nv_bfloat16, uint8_t>;
 #endif
-template class CutlassFpAIntBGemmRunner<float, cutlass::uint4b_t>;
 template class CutlassFpAIntBGemmRunner<half, cutlass::uint4b_t>;
 #ifdef PADDLE_CUDA_BF16
 template class CutlassFpAIntBGemmRunner<__nv_bfloat16, cutlass::uint4b_t>;
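This hunk deletes the float specializations, which existed only to throw at runtime, along with their explicit instantiations; after the change, a float activation path presumably fails at link time rather than at runtime, and callers use the surviving half or __nv_bfloat16 instantiations. For reference, a minimal usage sketch of the remaining half/uint8_t runner: the header path, the wrapper function, and the group_size = -1 per-channel convention are assumptions for illustration, while the gemm/getWorkspaceSize signatures mirror the float versions deleted above.

// A minimal sketch, not part of the patch: header path, wrapper name, and
// the group_size convention below are assumptions; the method signatures
// mirror the float specialization removed above.
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <cstdint>
#include "fpA_intB_gemm.h"  // assumed header exposing CutlassFpAIntBGemmRunner

void run_w8a16_gemm(const half* A,              // [m, k] fp16 activations
                    const uint8_t* B,           // [k, n] int8-quantized weights
                    const half* weight_scales,  // per-output-channel scales
                    half* C,                    // [m, n] fp16 output
                    int m, int n, int k, cudaStream_t stream) {
  CutlassFpAIntBGemmRunner<half, uint8_t> runner;

  // Scratch for the split-k reduction; per the code above this is
  // max_grid_m * max_grid_n * split_k_limit * 4 bytes.
  int workspace_bytes = runner.getWorkspaceSize(m, n, k);
  char* workspace = nullptr;
  cudaMalloc(&workspace, workspace_bytes);

  // group_size = -1 is assumed here to mean per-channel (non-grouped)
  // quantization; check the runner's actual convention before relying on it.
  runner.gemm(A, B, weight_scales, C, m, n, k, /*group_size=*/-1,
              workspace, static_cast<size_t>(workspace_bytes), stream);

  cudaFree(workspace);
}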