11#include < stdint.h>
22
3- __device__ __forceinline__ float raw_half_to_float (uint16_t h) {
4- float out;
5- asm volatile (" cvt.f32.f16 %0, %1;" : " =f" (out) : " h" (h));
6- return out;
7- }
8-
93__device__ auto
104MatMul (__attribute__((address_space(3 ))) float *output_allocated,
115 __attribute__((address_space(3 ))) float *output_aligned,
126 const int64_t output_offsets, const int64_t output_size1,
137 const int64_t output_size2, const int64_t output_stride1,
148 const int64_t output_stride2,
15- __attribute__((address_space(3 ))) uint16_t *a_allocated,
16- __attribute__((address_space(3 ))) uint16_t *a_aligned,
9+ __attribute__((address_space(3 ))) __fp16 *a_allocated,
10+ __attribute__((address_space(3 ))) __fp16 *a_aligned,
1711 const int64_t a_offsets, const int64_t a_size1, const int64_t a_size2,
1812 const int64_t a_stride1, const int64_t a_stride2,
19- __attribute__((address_space(3 ))) uint16_t *b_allocated,
20- __attribute__((address_space(3 ))) uint16_t *b_aligned,
13+ __attribute__((address_space(3 ))) __fp16 *b_allocated,
14+ __attribute__((address_space(3 ))) __fp16 *b_aligned,
2115 const int64_t b_offsets, const int64_t b_size1, const int64_t b_size2,
2216 const int64_t b_stride1, const int64_t b_stride2) {
2317 const int idx = threadIdx .x ;
@@ -29,13 +23,10 @@ MatMul(__attribute__((address_space(3))) float *output_allocated,
2923 for (int i = idx; i < m * n; i += bdimx) {
3024 int row = i / n;
3125 int col = i % n;
32- float acc = 0 . 0f ;
26+ float acc = 0 ;
3327 for (int j = 0 ; j < k; j++) {
34- float a_val = raw_half_to_float (
35- a_aligned[a_offsets + row * a_stride1 + j * a_stride2]);
36- float b_val = raw_half_to_float (
37- b_aligned[b_offsets + j * b_stride1 + col * b_stride2]);
38- acc += a_val * b_val;
28+ acc += a_aligned[a_offsets + row * a_stride1 + j * a_stride2] *
29+ b_aligned[b_offsets + j * b_stride1 + col * b_stride2];
3930 }
4031 output_aligned[output_offsets + row * output_stride1 +
4132 col * output_stride2] += acc;
@@ -47,8 +38,8 @@ MatMul(__attribute__((address_space(3))) float *output_allocated,
4738 __attribute__ ((address_space(3 ))) float *allocated;
4839 __attribute__ ((address_space(3 ))) float *aligned;
4940 int64_t offsets;
50- int64_t sizes1 [2 ];
51- int64_t stride1 [2 ];
41+ int64_t sizes [2 ];
42+ int64_t strides [2 ];
5243 } r{output_allocated,
5344 output_aligned,
5445 output_offsets,
0 commit comments