Skip to content

Commit 315aef4

Browse files
committed
refine
1 parent ebc6cfb commit 315aef4

3 files changed

Lines changed: 60 additions & 10 deletions

File tree

slm/model_zoo/gpt-3/external_ops/fused_quanted_ops/fused_spaq.cu

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,17 @@
1+
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
115
#include "quant_utils.h"
216

317
#define LAUNCH_FUSED_SPAQ(__using_pow2_scaling, __with_prob) \
@@ -196,7 +210,8 @@ __global__ void FusedSPAQKernel(const phi::bfloat16 *__restrict__ Xin,
196210
const int quant_block_idx =
197211
threadIdx.x / 128; // 0 or 1, two quant blocks per block
198212
const int64_t in_y_idx = blockIdx.y;
199-
const int64_t in_x_idx = static_cast<uint64_t>(blockIdx.x) * blockDim.x + x_offset;
213+
const int64_t in_x_idx =
214+
static_cast<uint64_t>(blockIdx.x) * blockDim.x + x_offset;
200215
const int64_t src_idx = in_y_idx * cols + in_x_idx;
201216

202217
// Load data and compute swiGLU activation

slm/model_zoo/gpt-3/external_ops/fused_quanted_ops/fused_stack_transpose_quant.cu

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,17 @@
1+
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
115
#include "quant_utils.h"
216

317
template <typename T, int VecSize>
@@ -199,7 +213,8 @@ __global__ void __launch_bounds__(1024)
199213
using StoreT = VecType<OutT, 4>;
200214
StoreT data;
201215
for (int j = 0; j < 4; j++) {
202-
data[j] = shm[i * 32 + threadIdx.y][static_cast<size_t>(threadIdx.x) * 4 + j];
216+
data[j] =
217+
shm[i * 32 + threadIdx.y][static_cast<size_t>(threadIdx.x) * 4 + j];
203218
}
204219
*reinterpret_cast<StoreT*>(out + idx) = data;
205220
}

slm/model_zoo/gpt-3/external_ops/fused_quanted_ops/fused_transpose_split_quant.cu

Lines changed: 28 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,17 @@
1+
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
115
#include "quant_utils.h"
216

317
template <typename T, int VecSize>
@@ -15,7 +29,8 @@ __device__ void BlockLoad(const phi::bfloat16* X,
1529
size_t K) {
1630
for (size_t i = 0; i < 4; i++) {
1731
size_t off_m = static_cast<size_t>(blockIdx.x) * 128 + threadIdx.y + i * 32;
18-
size_t off_k = static_cast<size_t>(blockIdx.y) * 128 + threadIdx.x * VecSize;
32+
size_t off_k =
33+
static_cast<size_t>(blockIdx.y) * 128 + threadIdx.x * VecSize;
1934
size_t offset = off_m * K + off_k;
2035

2136
for (size_t j = 0; j < 4; j += VecSize) {
@@ -45,15 +60,18 @@ __device__ void BlockColumnMax(const __nv_bfloat16 input[4][4],
4560

4661
// Reduce [(32), 32, 4] => [32, 4]
4762
for (int i = 0; i < 4; i++) {
48-
shm[static_cast<size_t>(threadIdx.y) * 128 + i * 32 + threadIdx.x] = warp_max[i];
63+
shm[static_cast<size_t>(threadIdx.y) * 128 + i * 32 + threadIdx.x] =
64+
warp_max[i];
4965
}
5066
__syncthreads();
5167
for (int offset = 16; offset > 0; offset /= 2) {
5268
if (threadIdx.y < offset) {
5369
for (int i = 0; i < 4; i++) {
5470
shm[static_cast<size_t>(threadIdx.y) * 128 + i * 32 + threadIdx.x] =
55-
__hmax(shm[static_cast<size_t>(threadIdx.y) * 128 + i * 32 + threadIdx.x],
56-
shm[(static_cast<size_t>(threadIdx.y) + offset) * 128 + i * 32 + threadIdx.x]);
71+
__hmax(shm[static_cast<size_t>(threadIdx.y) * 128 + i * 32 +
72+
threadIdx.x],
73+
shm[(static_cast<size_t>(threadIdx.y) + offset) * 128 +
74+
i * 32 + threadIdx.x]);
5775
}
5876
}
5977
__syncthreads();
@@ -79,7 +97,8 @@ __device__ void BlockStoreScale(float* scale,
7997
}
8098
if (threadIdx.y == 0) {
8199
size_t idx_m = blockIdx.x - off_m / 128;
82-
size_t off_k = static_cast<size_t>(blockIdx.y) * 128 + threadIdx.x * VecSize;
100+
size_t off_k =
101+
static_cast<size_t>(blockIdx.y) * 128 + threadIdx.x * VecSize;
83102
size_t offset = idx_m * K + off_k;
84103

85104
for (size_t j = 0; j < 4; j += VecSize) {
@@ -111,7 +130,8 @@ __device__ void BlockStoreOut(OutT* out,
111130
using StoreT = VecType<OutT, VecSize>;
112131
StoreT data;
113132
for (int j = 0; j < VecSize; j++) {
114-
data[j] = shm[i * 32 + threadIdx.y][static_cast<size_t>(threadIdx.x) * 4 + j];
133+
data[j] =
134+
shm[i * 32 + threadIdx.y][static_cast<size_t>(threadIdx.x) * 4 + j];
115135
}
116136
*reinterpret_cast<StoreT*>(out + idx) = data;
117137
}
@@ -176,8 +196,8 @@ __global__ void __launch_bounds__(1024)
176196
for (int k = 0; k < VecSize; k++) {
177197
float input_fp32 = static_cast<float>(input[i][j + k]);
178198
float output_scaled = input_fp32 * scale_inv[j + k];
179-
shm[static_cast<size_t>(threadIdx.x) * VecSize + j * 32 + k][i * 32 + threadIdx.y] =
180-
static_cast<OutT>(output_scaled);
199+
shm[static_cast<size_t>(threadIdx.x) * VecSize + j * 32 + k]
200+
[i * 32 + threadIdx.y] = static_cast<OutT>(output_scaled);
181201
}
182202
}
183203
}

0 commit comments

Comments
 (0)