Skip to content

Commit d33ffbf

Browse files
authored
FP16 Postop Support (#512)
- The postop support for Pure FP16 follow a partial-Spill approach to accommodate register pressure ------------------------------------------------------ AMD Internal : [SWLCGS-4198] Signed-off-by: John Alexander <joalexan_amdeng@amd.com>
1 parent 01c17d5 commit d33ffbf

13 files changed

Lines changed: 841 additions & 66 deletions

classic/aocl_gemm_f16f16f16of16.c

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -257,15 +257,6 @@ aocl_gemm_f16f16f16of16(const char order,
257257
dlp_param_map_char_to_lpmtag(mem_format_a, &mtag_a);
258258
dlp_param_map_char_to_lpmtag(mem_format_b, &mtag_b);
259259

260-
// Post-ops are not supported for FP16 GEMM.
261-
// Check if any post-ops are specified and return error if so.
262-
if ((metadata != NULL) && (metadata->seq_length > 0)) {
263-
dlp_print_msg(" Post-ops are not supported for f16f16f16of16 gemm.",
264-
__FILE__, __LINE__);
265-
DLP_METADATA_SET_ERROR(metadata, DLP_CLSC_NOT_SUPPORTED);
266-
goto err_hndl;
267-
}
268-
269260
// Check for A-dequantization post-op (a_post_quant)
270261
if ((metadata != NULL) && (metadata->a_post_quant != NULL)) {
271262
dlp_print_msg(" A-dequantization post-op is not supported for "

classic/frame/dlp_gemm_post_ops.c

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,9 @@ dlp_gemm_get_stor_type(DLP_TYPE pstor_type)
4141
case DLP_BF16:
4242
stor_type = DLP_BF16;
4343
break;
44+
case DLP_F16:
45+
stor_type = DLP_F16;
46+
break;
4447
case DLP_S8:
4548
stor_type = DLP_S8;
4649
break;

classic/frame/fp16fp16fp16/dlp_gemm_fp16.c

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ DLP_GEMV(float16, float16, float16, f16f16f16of16)
9898
post_ops_attr.b_col_sum_vec = NULL;
9999
post_ops_attr.b_col_sum_vec_s16 = NULL;
100100

101-
post_ops_attr.buf_downscale = NULL;
101+
post_ops_attr.buf_downscale = c;
102102

103103
/* Generate thrinfo objects for jc and ic loops */
104104
dlp_task_id_t thread_jc;
@@ -360,7 +360,7 @@ DLP_GEMM_5LOOP_UNIFIED(float16, float16, float16, float16, f16f16f16of16,
360360
post_ops_attr.b_sum_offset = 0;
361361
post_ops_attr.b_col_sum_vec = NULL;
362362
post_ops_attr.b_col_sum_vec_s16 = NULL;
363-
post_ops_attr.buf_downscale = NULL;
363+
post_ops_attr.buf_downscale = c;
364364

365365
dlp_task_id_t thread_jc;
366366
dlp_task_id_t thread_ic;

src/jit/amdzen/amdzen_generator.cc

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3003,6 +3003,16 @@ jitAmdZenFP16::executeKernel(dlp::kernels::kernelParams* _params)
30033003
params->nmask_fp16_avx512 =
30043004
0xFFFFFFFFu >> (numElemsPerReg - partial_elements);
30053005
}
3006+
3007+
// F32 postops mask: 16 F32 elements per ZMM
3008+
static constexpr int F32_PER_ZMM = 16;
3009+
int f32_partial = params->n_left % F32_PER_ZMM;
3010+
if (f32_partial == 0) {
3011+
params->nmask_avx512 = 0xFFFFu;
3012+
} else {
3013+
params->nmask_avx512 = static_cast<uint16_t>(
3014+
0xFFFFu >> (F32_PER_ZMM - f32_partial));
3015+
}
30063016
}
30073017

30083018
// Deploy the associated kernel based on n_left

0 commit comments

Comments
 (0)