1+ // Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
2+ //
3+ // Licensed under the Apache License, Version 2.0 (the "License");
4+ // you may not use this file except in compliance with the License.
5+ // You may obtain a copy of the License at
6+ //
7+ // http://www.apache.org/licenses/LICENSE-2.0
8+ //
9+ // Unless required by applicable law or agreed to in writing, software
10+ // distributed under the License is distributed on an "AS IS" BASIS,
11+ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+ // See the License for the specific language governing permissions and
13+ // limitations under the License.
14+
115#include " quant_utils.h"
216
317template <typename T, int VecSize>
@@ -15,7 +29,8 @@ __device__ void BlockLoad(const phi::bfloat16* X,
1529 size_t K) {
1630 for (size_t i = 0 ; i < 4 ; i++) {
1731 size_t off_m = static_cast <size_t >(blockIdx .x ) * 128 + threadIdx .y + i * 32 ;
18- size_t off_k = static_cast <size_t >(blockIdx .y ) * 128 + threadIdx .x * VecSize;
32+ size_t off_k =
33+ static_cast <size_t >(blockIdx .y ) * 128 + threadIdx .x * VecSize;
1934 size_t offset = off_m * K + off_k;
2035
2136 for (size_t j = 0 ; j < 4 ; j += VecSize) {
@@ -45,15 +60,18 @@ __device__ void BlockColumnMax(const __nv_bfloat16 input[4][4],
4560
4661 // Reduce [(32), 32, 4] => [32, 4]
4762 for (int i = 0 ; i < 4 ; i++) {
48- shm[static_cast <size_t >(threadIdx .y ) * 128 + i * 32 + threadIdx .x ] = warp_max[i];
63+ shm[static_cast <size_t >(threadIdx .y ) * 128 + i * 32 + threadIdx .x ] =
64+ warp_max[i];
4965 }
5066 __syncthreads ();
5167 for (int offset = 16 ; offset > 0 ; offset /= 2 ) {
5268 if (threadIdx .y < offset) {
5369 for (int i = 0 ; i < 4 ; i++) {
5470 shm[static_cast <size_t >(threadIdx .y ) * 128 + i * 32 + threadIdx .x ] =
55- __hmax (shm[static_cast <size_t >(threadIdx .y ) * 128 + i * 32 + threadIdx .x ],
56- shm[(static_cast <size_t >(threadIdx .y ) + offset) * 128 + i * 32 + threadIdx .x ]);
71+ __hmax (shm[static_cast <size_t >(threadIdx .y ) * 128 + i * 32 +
72+ threadIdx .x ],
73+ shm[(static_cast <size_t >(threadIdx .y ) + offset) * 128 +
74+ i * 32 + threadIdx .x ]);
5775 }
5876 }
5977 __syncthreads ();
@@ -79,7 +97,8 @@ __device__ void BlockStoreScale(float* scale,
7997 }
8098 if (threadIdx .y == 0 ) {
8199 size_t idx_m = blockIdx .x - off_m / 128 ;
82- size_t off_k = static_cast <size_t >(blockIdx .y ) * 128 + threadIdx .x * VecSize;
100+ size_t off_k =
101+ static_cast <size_t >(blockIdx .y ) * 128 + threadIdx .x * VecSize;
83102 size_t offset = idx_m * K + off_k;
84103
85104 for (size_t j = 0 ; j < 4 ; j += VecSize) {
@@ -111,7 +130,8 @@ __device__ void BlockStoreOut(OutT* out,
111130 using StoreT = VecType<OutT, VecSize>;
112131 StoreT data;
113132 for (int j = 0 ; j < VecSize; j++) {
114- data[j] = shm[i * 32 + threadIdx .y ][static_cast <size_t >(threadIdx .x ) * 4 + j];
133+ data[j] =
134+ shm[i * 32 + threadIdx .y ][static_cast <size_t >(threadIdx .x ) * 4 + j];
115135 }
116136 *reinterpret_cast <StoreT*>(out + idx) = data;
117137 }
@@ -176,8 +196,8 @@ __global__ void __launch_bounds__(1024)
176196 for (int k = 0 ; k < VecSize; k++) {
177197 float input_fp32 = static_cast <float >(input[i][j + k]);
178198 float output_scaled = input_fp32 * scale_inv[j + k];
179- shm[static_cast <size_t >(threadIdx .x ) * VecSize + j * 32 + k][i * 32 + threadIdx . y ] =
180- static_cast <OutT>(output_scaled);
199+ shm[static_cast <size_t >(threadIdx .x ) * VecSize + j * 32 + k]
200+ [i * 32 + threadIdx . y ] = static_cast <OutT>(output_scaled);
181201 }
182202 }
183203 }
0 commit comments