Skip to content

Commit 7d87aaa

Browse files
authored
optimize w4a8 decoding (PaddlePaddle#3050)
1 parent e80ea8a commit 7d87aaa

File tree

6 files changed

+253
-36
lines changed

6 files changed

+253
-36
lines changed

custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/w4a8_gemm_grouped.h

Lines changed: 38 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -223,22 +223,18 @@ class W4A8MoeGemmUniversalBase {
223223
static Status can_implement(Arguments const &args)
224224
{
225225
CUTLASS_TRACE_HOST("W4A8MoeGemmUniversalBase::can_implement()");
226-
// printf("--1\n");
227226
// Initialize static kernel and device properties, if necessary.
228227
Status result = init_device_props();
229-
// printf("--1-2\n");
230228
if (result != Status::kSuccess) {
231229
return result;
232230
}
233-
// printf("--2\n");
234231
dim3 grid = get_grid_shape(args);
235232
// printf("--grid:%d, %d, %d\n", grid.x, grid.y, grid.z);
236233
if (!(grid.y <= std::numeric_limits<uint16_t>::max() &&
237234
grid.z <= std::numeric_limits<uint16_t>::max()))
238235
{
239236
return Status::kErrorInvalidProblem;
240237
}
241-
// printf("--3\n");
242238
return GemmKernel::can_implement(args);
243239
}
244240

@@ -285,18 +281,50 @@ class W4A8MoeGemmUniversalBase {
285281
}
286282

287283

284+
288285
/// Returns the maximum number of active thread blocks per multiprocessor
289-
static int maximum_active_blocks()
286+
static int maximum_active_blocks(int smem_capacity = -1)
290287
{
291288
CUTLASS_TRACE_HOST("W4A8MoeGemmUniversalBase::maximum_active_blocks()");
292289

293-
// Initialize static device properties, if necessary
294-
if (init_device_props() != Status::kSuccess) {
290+
int smem_size = int(sizeof(typename GemmKernel_::SharedStorage));
291+
292+
CUTLASS_TRACE_HOST(" smem_size: " << smem_size << " bytes");
293+
294+
cudaError_t result;
295+
if (smem_size > (48 << 10)) {
296+
result = cudaFuncSetAttribute(Kernel2<GemmKernel_>,
297+
cudaFuncAttributeMaxDynamicSharedMemorySize,
298+
smem_size);
299+
300+
if (result != cudaSuccess) {
301+
// Call cudaGetLastError() to clear the error bit
302+
result = cudaGetLastError();
303+
CUTLASS_TRACE_HOST(
304+
" cudaFuncSetAttribute() returned error "
305+
<< cudaGetErrorString(result));
306+
return -1;
307+
}
308+
}
309+
310+
int max_active_blocks = -1;
311+
result = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
312+
&max_active_blocks,
313+
Kernel2<GemmKernel_>,
314+
GemmKernel_::kThreadCount,
315+
smem_size);
316+
317+
if (result != cudaSuccess) {
318+
// Call cudaGetLastError() to clear the error bit
319+
result = cudaGetLastError();
320+
CUTLASS_TRACE_HOST(
321+
" cudaOccupancyMaxActiveBlocksPerMultiprocessor() returned error "
322+
<< cudaGetErrorString(result));
295323
return -1;
296324
}
297325

298-
CUTLASS_TRACE_HOST(" max_active_blocks: " << sm_occupancy_);
299-
return sm_occupancy_;
326+
CUTLASS_TRACE_HOST(" max_active_blocks: " << max_active_blocks);
327+
return max_active_blocks;
300328
}
301329

302330

@@ -341,8 +369,7 @@ class W4A8MoeGemmUniversalBase {
341369

342370
// Configure grid and block dimensions
343371
dim3 block(GemmKernel::kThreadCount, 1, 1);
344-
// dim3 grid = params_.get_grid_dims();
345-
dim3 grid(216, 1, 1);
372+
dim3 grid(params_.threadblock_count, 1, 1);
346373

347374
// Launch kernel
348375
CUTLASS_TRACE_HOST(" "

custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/w4a8_moe_gemm_config_search.sh

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,12 +21,12 @@ rm -rf up_gate_proj_7168_8192.log
2121
rm -rf down_proj_8192_3584.log
2222
num_experts=8
2323

24-
for tokens_per_expert in 12
24+
for tokens_per_expert in 1 2 4 8 16 20 24 28 32 36 48 64 96 128 160 192 224 256 384 512 768 1024 2048 3072 4096 8192
2525

2626
do
2727
wait
28-
CUDA_VISIBLE_DEVICES=2 ./w4a8_moe_gemm_test ${num_experts} ${up_gate_proj_n} ${up_gate_proj_k} ${tokens_per_expert} 1 0 >> up_gate_proj_${up_gate_proj_n}_${up_gate_proj_k}.log 2>&1 &
29-
# CUDA_VISIBLE_DEVICES=3 ./w4a8_moe_gemm_test ${num_experts} ${down_proj_n} ${down_proj_k} ${tokens_per_expert} 1 0 >> down_proj_${down_proj_n}_${down_proj_k}.log 2>&1 &
28+
CUDA_VISIBLE_DEVICES=2 ./w4a8_moe_gemm_test ${num_experts} ${ffn1_n} ${ffn1_k} ${tokens_per_expert} 0 1 >> ffn1_${ffn1_n}_${ffn1_k}.log 2>&1 &
29+
CUDA_VISIBLE_DEVICES=3 ./w4a8_moe_gemm_test ${num_experts} ${ffn2_n} ${ffn2_k} ${tokens_per_expert} 0 1 >> ffn2_${ffn2_n}_${ffn2_k}.log 2>&1 &
3030
done
3131
wait
3232
echo "#### finish ####"

custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/w4a8_moe_gemm_test.cu

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -996,7 +996,6 @@ int main(int argc, char *argv[]) {
996996
CutlassTileConfig::CtaShape64x256x64_WarpShape64x64x64,
997997
CutlassTileConfig::CtaShape32x512x64_WarpShape32x128x64,
998998
CutlassTileConfig::CtaShape128x128x64_WarpShape128x32x64,
999-
CutlassTileConfig::CtaShape32x512x64_WarpShape32x128x64,
1000999
};
10011000
std::vector<SplitKStyle> all_split_k_style{SplitKStyle::NO_SPLIT_K};
10021001

0 commit comments

Comments (0)