2727 */
2828
2929#include "aocl_dlp_gemm_check.h"
30+ #include "classic/aocl_fp16_convert.h"
3031#include "classic/aocl_gemm_interface_apis.h"
3132#include "classic/aocl_lib_interface_apis.h"
3233#include "classic/dlp_errors.h"
4041#include "runtime/dlp_runtime.h"
4142#include "threading/dlp_gemm_thread_decor_openmp.h"
4243
43- #if defined(__F16C__ ) && defined(__GNUC__ )
44- #include <immintrin.h>
45- #endif
46-
47- /**
48- * @brief Convert float32 to float16
49- *
50- * Uses compiler intrinsics when available (_cvtss_sh with F16C),
51- * otherwise falls back to portable software bit manipulation.
52- * Uses round-to-nearest-even rounding mode per IEEE-754.
53- *
54- * The software fallback correctly handles:
55- * - Round-to-nearest-even rounding
56- * - NaN propagation (preserves quiet NaN)
57- * - Subnormal denormalization with proper rounding
58- * - Overflow to infinity
59- * - Underflow to zero or subnormal
60- * - Rounding-induced exponent increment
61- */
62- static inline float16
63- f32_to_fp16 (float f32_val )
64- {
65- #if defined(__F16C__ ) && defined(__GNUC__ )
66- /* Use F16C intrinsic for hardware conversion */
67- return (float16 )_cvtss_sh (f32_val , 0 );
68- #else
69- /* Software conversion from float32 to float16 with IEEE-754 rounding */
70- union
71- {
72- float f ;
73- uint32_t u ;
74- } x ;
75- x .f = f32_val ;
76-
77- /* Extract components */
78- uint32_t sign = (x .u & 0x80000000U ) >> 16 ; /* Bit 31 → 15 */
79- int32_t exp32 = ((x .u & 0x7F800000U ) >> 23 ); /* Extract exponent */
80- uint32_t mant32 = (x .u & 0x007FFFFFU ); /* Extract mantissa */
81-
82- /* Special case: FP32 zero or subnormal */
83- if (exp32 == 0 ) {
84- /* FP32 subnormals are too small for FP16, flush to signed zero */
85- return (float16 )(sign );
86- }
87-
88- /* Special case: FP32 infinity or NaN */
89- if (exp32 == 0xFF ) {
90- if (mant32 == 0 ) {
91- /* Infinity */
92- return (float16 )(sign | 0x7C00U );
93- } else {
94- /* NaN: preserve some mantissa bits, ensure quiet NaN */
95- uint16_t mant16 = (uint16_t )((mant32 >> 13 ) | 0x0200U );
96- return (float16 )(sign | 0x7C00U | (mant16 & 0x03FFU ));
97- }
98- }
99-
100- /* Rebias exponent: FP32 bias=127, FP16 bias=15 */
101- int32_t exp16 = exp32 - 112 ; /* exp32 - 127 + 15 = exp32 - 112 */
102-
103- /* Add implicit leading 1 to mantissa for calculations */
104- mant32 |= 0x00800000U ;
105-
106- /* Check for underflow (handle denormals) */
107- if (exp16 <= 0 ) {
108- if (exp16 < -10 ) {
109- /* Too small, flush to zero */
110- return (float16 )(sign );
111- }
112-
113- /*
114- * Denormalize: shift mantissa right to align with FP16 denormal format.
115- * For FP16 denormals, the value is: mantissa * 2^-24
116- * We need to shift the 24-bit mantissa (with implicit 1) right by
117- * (14 - exp32 + 127) = (141 - exp32) positions to get the 10-bit
118- * result. This is equivalent to shifting by (1 - exp16 + 13) = (14 -
119- * exp16).
120- */
121- int total_shift = 14 - exp16 ; /* Total shift to get 10-bit mantissa */
122-
123- /* Round to nearest even using the bits that will be shifted out */
124- uint32_t round_bit = (mant32 >> (total_shift - 1 )) & 1 ;
125- uint32_t sticky_mask = (1U << (total_shift - 1 )) - 1 ;
126- uint32_t sticky = (mant32 & sticky_mask ) != 0 ;
127- uint32_t lsb = (mant32 >> total_shift ) & 1 ;
128-
129- /* Compute the shifted mantissa */
130- uint32_t mant16 = mant32 >> total_shift ;
131-
132- /* Apply round-to-nearest-even */
133- if (round_bit && (sticky || lsb )) {
134- mant16 ++ ;
135- }
136-
137- /* Check if rounding caused normalization (overflow into bit 10) */
138- if (mant16 >= 0x0400U ) {
139- return (float16 )(sign | 0x0400U ); /* Smallest normal */
140- }
141-
142- return (float16 )(sign | (uint16_t )mant16 );
143- }
144-
145- /* Check for overflow before rounding */
146- if (exp16 >= 0x1F ) {
147- return (float16 )(sign | 0x7C00U );
148- }
149-
150- /* Normal value: Round mantissa from 23 to 10 bits */
151- uint32_t round_bits = mant32 & 0x1FFFU ; /* Bits 12-0 */
152- uint32_t lsb = (mant32 >> 13 ) & 1 ;
153-
154- /* Round to nearest even */
155- if (round_bits > 0x1000U || (round_bits == 0x1000U && lsb )) {
156- mant32 += 0x1000U ;
157- }
158-
159- /* Check for carry into exponent AFTER rounding */
160- if (mant32 & 0x01000000U ) {
161- /* Mantissa overflowed into bit 24 */
162- exp16 ++ ;
163- mant32 = 0x00800000U ; /* Reset to implicit 1 only */
164-
165- /* Check if exponent overflowed to infinity */
166- if (exp16 >= 0x1F ) {
167- return (float16 )(sign | 0x7C00U );
168- }
169- }
170-
171- /* Extract rounded 10-bit mantissa (remove implicit 1) */
172- uint16_t mant16 = (uint16_t )((mant32 >> 13 ) & 0x03FFU );
173-
174- return (float16 )(sign | ((uint16_t )exp16 << 10 ) | mant16 );
175- #endif
176- }
177-
17844void
17945aocl_gemm_f16f16f16of16 (const char order ,
18046 const char transa ,
18147 const char transb ,
18248 const md_t m ,
18349 const md_t n ,
18450 const md_t k ,
185- const float alpha ,
51+ const float16 alpha ,
18652 const float16 * a ,
18753 const md_t lda ,
18854 const char mem_format_a ,
18955 const float16 * b ,
19056 const md_t ldb ,
19157 const char mem_format_b ,
192- const float beta ,
58+ const float16 beta ,
19359 float16 * c ,
19460 const md_t ldc ,
19561 dlp_metadata_t * metadata )
19662{
19763 DLP_GEMM_START_LOGGER ();
64+ // alpha/beta arrive as float16 (the FP16 GEMM API contract). The shared
65+ // logger prints them as %f, so widen once at the call boundary via
66+ // fp16_to_f32. The widening is for printing only and never propagates
67+ // back into computation.
19868 DLP_GEMM_WRITE_LOGGER ("f16f16f16of16" , order , transa , transb , m , n , k ,
199- (( float ) alpha ), lda , mem_format_a , ldb , mem_format_b ,
200- (( float ) beta ), ldc , metadata );
69+ fp16_to_f32 ( alpha ), lda , mem_format_a , ldb ,
70+ mem_format_b , fp16_to_f32 ( beta ), ldc , metadata );
20171
20272 DLP_METADATA_SET_ERROR (metadata , DLP_CLSC_SUCCESS );
20373
@@ -388,20 +258,20 @@ aocl_gemm_f16f16f16of16(const char order,
388258 AOCL_DLP_MEMORY_TAG jit_mtag_a = mtag_a_use ;
389259 AOCL_DLP_MEMORY_TAG jit_mtag_b = mtag_b_use ;
390260
391- // Convert alpha and beta from float to float16 for JIT kernel.
392- // The FP16 JIT kernel uses vpbroadcastw (16-bit broadcast) to load
393- // alpha/beta, so we must pass FP16 addresses.
394- float16 alpha_fp16 = f32_to_fp16 (alpha );
395- float16 beta_fp16 = f32_to_fp16 (beta );
396-
397261 // Initialize DLP Plus kernel path (JIT support)
398262 lcntx_l .dlp_kernel_hndl .kernel_base = NULL ;
399263
264+ // alpha/beta are passed as FP16. The decision engine reads
265+ // (void*)&alpha and (void*)&beta as float16* via
266+ // getScalingTypes<dlp::float16>, and the JIT consumes alpha natively
267+ // as FP16 (vpbroadcastw + vmulph). Beta is consumed as FP16 on the
268+ // of16 rail and widened to float by the 5-loop before each kernel
269+ // call on the of32 rail.
400270 dlp_init_and_get_kernel_hndl (
401271 DLP_KERNEL_F16F16F16OF16 , order , jit_mtag_a , jit_mtag_b , m_use , n_use ,
402- k , rs_a_use , cs_a_use , rs_b_use , cs_b_use , rs_c , cs_c ,
403- (void * )& alpha_fp16 , ( void * ) & beta_fp16 , post_op_list , mr_hint , nr_hint ,
404- kc_hint , DLP_F16 , & lcntx_l .dlp_kernel_hndl );
272+ k , rs_a_use , cs_a_use , rs_b_use , cs_b_use , rs_c , cs_c , ( void * ) & alpha ,
273+ (void * )& beta , post_op_list , mr_hint , nr_hint , kc_hint , DLP_F16 ,
274+ & lcntx_l .dlp_kernel_hndl );
405275
406276 // FP16 is JIT-only (no intrinsic fallback), so check if JIT succeeded
407277 if (lcntx_l .dlp_kernel_hndl .kernel_base == NULL ) {
@@ -417,13 +287,13 @@ aocl_gemm_f16f16f16of16(const char order,
417287#ifdef DLP_ENABLE_OPENMP
418288 dlp_gemm_f16f16f16of16_openmp_thread_decorator (
419289 m_use , n_use , k , a_use , rs_a_use , cs_a_use , mtag_a_use , b_use , rs_b_use ,
420- cs_b_use , mtag_b_use , c , rs_c , cs_c , alpha_fp16 , beta_fp16 , & rntm_g ,
421- & lcntx_l , & ops , DLP_F16 );
290+ cs_b_use , mtag_b_use , c , rs_c , cs_c , alpha , beta , & rntm_g , & lcntx_l ,
291+ & ops , DLP_F16 );
422292#else
423293 dlp_gemm_f16f16f16of16_thread_decorator (
424294 m_use , n_use , k , a_use , rs_a_use , cs_a_use , mtag_a_use , b_use , rs_b_use ,
425- cs_b_use , mtag_b_use , c , rs_c , cs_c , alpha_fp16 , beta_fp16 , & rntm_g ,
426- & lcntx_l , & ops , DLP_F16 );
295+ cs_b_use , mtag_b_use , c , rs_c , cs_c , alpha , beta , & rntm_g , & lcntx_l ,
296+ & ops , DLP_F16 );
427297#endif
428298
429299err_hndl :;
0 commit comments