Skip to content

Commit 3a43dd1

Browse files
authored
FP16oF32 Support
Added support for oF32 in the F16 API. AMD-Internal: [SWLCSG-4198] -------------------------------------------------------- Signed-off-by: John Alexander <joalexan_amdeng@amd.com>
1 parent 08ea459 commit 3a43dd1

28 files changed

Lines changed: 2421 additions & 552 deletions

classic/aocl_gemm_f16f16f16of16.c

Lines changed: 22 additions & 152 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
*/
2828

2929
#include "aocl_dlp_gemm_check.h"
30+
#include "classic/aocl_fp16_convert.h"
3031
#include "classic/aocl_gemm_interface_apis.h"
3132
#include "classic/aocl_lib_interface_apis.h"
3233
#include "classic/dlp_errors.h"
@@ -40,164 +41,33 @@
4041
#include "runtime/dlp_runtime.h"
4142
#include "threading/dlp_gemm_thread_decor_openmp.h"
4243

43-
#if defined(__F16C__) && defined(__GNUC__)
44-
#include <immintrin.h>
45-
#endif
46-
47-
/**
48-
* @brief Convert float32 to float16
49-
*
50-
* Uses compiler intrinsics when available (_cvtss_sh with F16C),
51-
* otherwise falls back to portable software bit manipulation.
52-
* Uses round-to-nearest-even rounding mode per IEEE-754.
53-
*
54-
* The software fallback correctly handles:
55-
* - Round-to-nearest-even rounding
56-
* - NaN propagation (the result always has the quiet bit set; the sign and
* the top mantissa bits of the input NaN are kept)
57-
* - Subnormal denormalization with proper rounding
58-
* - Overflow to infinity
59-
* - Underflow to zero or subnormal
60-
* - Rounding-induced exponent increment
*
* FP32 zero and FP32 subnormal inputs both produce a signed zero.
*
* @param f32_val Single-precision value to narrow.
* @return The 16-bit half-precision encoding (1 sign, 5 exponent, 10
* mantissa bits) of f32_val.
*
* NOTE(review): every return path casts an integer bit pattern to float16,
* which presumably relies on float16 being an integer-like 16-bit storage
* type rather than an arithmetic half type -- confirm against the typedef.
61-
*/
62-
static inline float16
63-
f32_to_fp16(float f32_val)
64-
{
65-
#if defined(__F16C__) && defined(__GNUC__)
66-
/* Use F16C intrinsic for hardware conversion (rounding control 0 =
 * round to nearest even) */
67-
return (float16)_cvtss_sh(f32_val, 0);
68-
#else
69-
/* Software conversion from float32 to float16 with IEEE-754 rounding */
/* Read the float's raw bits through a union instead of a pointer cast */
70-
union
71-
{
72-
float f;
73-
uint32_t u;
74-
} x;
75-
x.f = f32_val;
76-
77-
/* Extract components */
78-
uint32_t sign = (x.u & 0x80000000U) >> 16; /* Bit 31 → 15 */
79-
int32_t exp32 = ((x.u & 0x7F800000U) >> 23); /* Biased FP32 exponent */
80-
uint32_t mant32 = (x.u & 0x007FFFFFU); /* 23-bit FP32 mantissa */
81-
82-
/* Special case: FP32 zero or subnormal */
83-
if (exp32 == 0) {
84-
/* FP32 subnormals are too small for FP16, flush to signed zero */
85-
return (float16)(sign);
86-
}
87-
88-
/* Special case: FP32 infinity or NaN */
89-
if (exp32 == 0xFF) {
90-
if (mant32 == 0) {
91-
/* Infinity */
92-
return (float16)(sign | 0x7C00U);
93-
} else {
94-
/* NaN: keep the top 10 mantissa bits and force the quiet bit (0x0200)
 * so the result mantissa can never become zero (i.e. infinity) */
95-
uint16_t mant16 = (uint16_t)((mant32 >> 13) | 0x0200U);
96-
return (float16)(sign | 0x7C00U | (mant16 & 0x03FFU));
97-
}
98-
}
99-
100-
/* Rebias exponent: FP32 bias=127, FP16 bias=15 */
101-
int32_t exp16 = exp32 - 112; /* exp32 - 127 + 15 = exp32 - 112 */
102-
103-
/* Add implicit leading 1 to mantissa for calculations */
104-
mant32 |= 0x00800000U;
105-
106-
/* Check for underflow (handle denormals) */
107-
if (exp16 <= 0) {
108-
if (exp16 < -10) {
109-
/* Below half of the smallest FP16 subnormal: rounds to zero */
110-
return (float16)(sign);
111-
}
112-
113-
/*
114-
* Denormalize: shift mantissa right to align with FP16 denormal format.
115-
* For FP16 denormals, the value is: mantissa * 2^-24
116-
* We need to shift the 24-bit mantissa (with implicit 1) right by
117-
* (14 - exp16) = (126 - exp32) positions to get the 10-bit result:
118-
* 13 bits to narrow the 23-bit fraction to 10 bits, plus (1 - exp16)
119-
* to denormalize. exp16 is in [-10, 0] here, so the shift is 14..24.
120-
*/
121-
int total_shift = 14 - exp16; /* Total shift to get 10-bit mantissa */
122-
123-
/* Round to nearest even using the bits that will be shifted out */
/* round_bit: highest discarded bit; sticky: any lower discarded bit set;
 * lsb: lowest kept bit (decides exact ties) */
124-
uint32_t round_bit = (mant32 >> (total_shift - 1)) & 1;
125-
uint32_t sticky_mask = (1U << (total_shift - 1)) - 1;
126-
uint32_t sticky = (mant32 & sticky_mask) != 0;
127-
uint32_t lsb = (mant32 >> total_shift) & 1;
128-
129-
/* Compute the shifted mantissa */
130-
uint32_t mant16 = mant32 >> total_shift;
131-
132-
/* Apply round-to-nearest-even */
133-
if (round_bit && (sticky || lsb)) {
134-
mant16++;
135-
}
136-
137-
/* Check if rounding caused normalization (overflow into bit 10) */
138-
if (mant16 >= 0x0400U) {
139-
return (float16)(sign | 0x0400U); /* Smallest normal */
140-
}
141-
142-
return (float16)(sign | (uint16_t)mant16);
143-
}
144-
145-
/* Check for overflow before rounding */
146-
if (exp16 >= 0x1F) {
147-
return (float16)(sign | 0x7C00U);
148-
}
149-
150-
/* Normal value: Round mantissa from 23 to 10 bits */
151-
uint32_t round_bits = mant32 & 0x1FFFU; /* Bits 12-0 */
152-
uint32_t lsb = (mant32 >> 13) & 1;
153-
154-
/* Round to nearest even: adding half an ULP (0x1000) carries into bit 13
 * exactly when the discarded bits are above half, or equal to half with
 * an odd kept LSB */
155-
if (round_bits > 0x1000U || (round_bits == 0x1000U && lsb)) {
156-
mant32 += 0x1000U;
157-
}
158-
159-
/* Check for carry into exponent AFTER rounding */
160-
if (mant32 & 0x01000000U) {
161-
/* Mantissa overflowed into bit 24 */
162-
exp16++;
163-
mant32 = 0x00800000U; /* Reset to implicit 1 only */
164-
165-
/* Check if exponent overflowed to infinity */
166-
if (exp16 >= 0x1F) {
167-
return (float16)(sign | 0x7C00U);
168-
}
169-
}
170-
171-
/* Extract rounded 10-bit mantissa (remove implicit 1) */
172-
uint16_t mant16 = (uint16_t)((mant32 >> 13) & 0x03FFU);
173-
174-
/* Assemble sign (bit 15) | exponent (bits 14-10) | mantissa (bits 9-0) */
return (float16)(sign | ((uint16_t)exp16 << 10) | mant16);
175-
#endif
176-
}
177-
17844
void
17945
aocl_gemm_f16f16f16of16(const char order,
18046
const char transa,
18147
const char transb,
18248
const md_t m,
18349
const md_t n,
18450
const md_t k,
185-
const float alpha,
51+
const float16 alpha,
18652
const float16* a,
18753
const md_t lda,
18854
const char mem_format_a,
18955
const float16* b,
19056
const md_t ldb,
19157
const char mem_format_b,
192-
const float beta,
58+
const float16 beta,
19359
float16* c,
19460
const md_t ldc,
19561
dlp_metadata_t* metadata)
19662
{
19763
DLP_GEMM_START_LOGGER();
64+
// alpha/beta arrive as float16 (the FP16 GEMM API contract). The shared
65+
// logger prints them as %f, so widen once at the call boundary via
66+
// fp16_to_f32. The widening is for printing only and never propagates
67+
// back into computation.
19868
DLP_GEMM_WRITE_LOGGER("f16f16f16of16", order, transa, transb, m, n, k,
199-
((float)alpha), lda, mem_format_a, ldb, mem_format_b,
200-
((float)beta), ldc, metadata);
69+
fp16_to_f32(alpha), lda, mem_format_a, ldb,
70+
mem_format_b, fp16_to_f32(beta), ldc, metadata);
20171

20272
DLP_METADATA_SET_ERROR(metadata, DLP_CLSC_SUCCESS);
20373

@@ -388,20 +258,20 @@ aocl_gemm_f16f16f16of16(const char order,
388258
AOCL_DLP_MEMORY_TAG jit_mtag_a = mtag_a_use;
389259
AOCL_DLP_MEMORY_TAG jit_mtag_b = mtag_b_use;
390260

391-
// Convert alpha and beta from float to float16 for JIT kernel.
392-
// The FP16 JIT kernel uses vpbroadcastw (16-bit broadcast) to load
393-
// alpha/beta, so we must pass FP16 addresses.
394-
float16 alpha_fp16 = f32_to_fp16(alpha);
395-
float16 beta_fp16 = f32_to_fp16(beta);
396-
397261
// Initialize DLP Plus kernel path (JIT support)
398262
lcntx_l.dlp_kernel_hndl.kernel_base = NULL;
399263

264+
// alpha/beta are passed as FP16. The decision engine reads
265+
// (void*)&alpha and (void*)&beta as float16* via
266+
// getScalingTypes<dlp::float16>, and the JIT consumes alpha natively
267+
// as FP16 (vpbroadcastw + vmulph). Beta is consumed as FP16 on the
268+
// of16 rail and widened to float by the 5-loop before each kernel
269+
// call on the of32 rail.
400270
dlp_init_and_get_kernel_hndl(
401271
DLP_KERNEL_F16F16F16OF16, order, jit_mtag_a, jit_mtag_b, m_use, n_use,
402-
k, rs_a_use, cs_a_use, rs_b_use, cs_b_use, rs_c, cs_c,
403-
(void*)&alpha_fp16, (void*)&beta_fp16, post_op_list, mr_hint, nr_hint,
404-
kc_hint, DLP_F16, &lcntx_l.dlp_kernel_hndl);
272+
k, rs_a_use, cs_a_use, rs_b_use, cs_b_use, rs_c, cs_c, (void*)&alpha,
273+
(void*)&beta, post_op_list, mr_hint, nr_hint, kc_hint, DLP_F16,
274+
&lcntx_l.dlp_kernel_hndl);
405275

406276
// FP16 is JIT-only (no intrinsic fallback), so check if JIT succeeded
407277
if (lcntx_l.dlp_kernel_hndl.kernel_base == NULL) {
@@ -417,13 +287,13 @@ aocl_gemm_f16f16f16of16(const char order,
417287
#ifdef DLP_ENABLE_OPENMP
418288
dlp_gemm_f16f16f16of16_openmp_thread_decorator(
419289
m_use, n_use, k, a_use, rs_a_use, cs_a_use, mtag_a_use, b_use, rs_b_use,
420-
cs_b_use, mtag_b_use, c, rs_c, cs_c, alpha_fp16, beta_fp16, &rntm_g,
421-
&lcntx_l, &ops, DLP_F16);
290+
cs_b_use, mtag_b_use, c, rs_c, cs_c, alpha, beta, &rntm_g, &lcntx_l,
291+
&ops, DLP_F16);
422292
#else
423293
dlp_gemm_f16f16f16of16_thread_decorator(
424294
m_use, n_use, k, a_use, rs_a_use, cs_a_use, mtag_a_use, b_use, rs_b_use,
425-
cs_b_use, mtag_b_use, c, rs_c, cs_c, alpha_fp16, beta_fp16, &rntm_g,
426-
&lcntx_l, &ops, DLP_F16);
295+
cs_b_use, mtag_b_use, c, rs_c, cs_c, alpha, beta, &rntm_g, &lcntx_l,
296+
&ops, DLP_F16);
427297
#endif
428298

429299
err_hndl:;

0 commit comments

Comments
 (0)