Skip to content

Commit eb9acb3

Browse files
authored
[TLERaw] Enable fp16 for TLERaw CUDA (#471)
Signed-off-by: Jinjie Liu <jjliu@baai.ac.cn>
1 parent b1ff7fd commit eb9acb3

File tree

2 files changed

+14
-22
lines changed

2 files changed

+14
-22
lines changed

python/tutorials/tle/raw/cuda/02-fused-softmax.cu

Lines changed: 5 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,5 @@
11
#include <math_constants.h>
2+
23
__device__ auto
34
SoftmaxKernel(__attribute__((address_space(3))) float *output_allocated,
45
__attribute__((address_space(3))) float *output_aligned,
@@ -50,11 +51,11 @@ SoftmaxKernel(__attribute__((address_space(3))) float *output_allocated,
5051
__attribute__((address_space(3))) float *allocated;
5152
__attribute__((address_space(3))) float *aligned;
5253
int64_t offsets;
53-
int64_t sizes1[1];
54-
int64_t stride1[1];
54+
int64_t sizes[1];
55+
int64_t strides[1];
5556
} r{
56-
output_allocated, output_aligned, output_offsets,
57-
output_size, output_stride,
57+
output_allocated, output_aligned, output_offsets,
58+
{output_size}, {output_stride},
5859
};
5960
return r;
6061
}

python/tutorials/tle/raw/cuda/03-matrix-multiplication.cu

Lines changed: 9 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -1,23 +1,17 @@
11
#include <stdint.h>
22

3-
__device__ __forceinline__ float raw_half_to_float(uint16_t h) {
4-
float out;
5-
asm volatile("cvt.f32.f16 %0, %1;" : "=f"(out) : "h"(h));
6-
return out;
7-
}
8-
93
__device__ auto
104
MatMul(__attribute__((address_space(3))) float *output_allocated,
115
__attribute__((address_space(3))) float *output_aligned,
126
const int64_t output_offsets, const int64_t output_size1,
137
const int64_t output_size2, const int64_t output_stride1,
148
const int64_t output_stride2,
15-
__attribute__((address_space(3))) uint16_t *a_allocated,
16-
__attribute__((address_space(3))) uint16_t *a_aligned,
9+
__attribute__((address_space(3))) __fp16 *a_allocated,
10+
__attribute__((address_space(3))) __fp16 *a_aligned,
1711
const int64_t a_offsets, const int64_t a_size1, const int64_t a_size2,
1812
const int64_t a_stride1, const int64_t a_stride2,
19-
__attribute__((address_space(3))) uint16_t *b_allocated,
20-
__attribute__((address_space(3))) uint16_t *b_aligned,
13+
__attribute__((address_space(3))) __fp16 *b_allocated,
14+
__attribute__((address_space(3))) __fp16 *b_aligned,
2115
const int64_t b_offsets, const int64_t b_size1, const int64_t b_size2,
2216
const int64_t b_stride1, const int64_t b_stride2) {
2317
const int idx = threadIdx.x;
@@ -29,13 +23,10 @@ MatMul(__attribute__((address_space(3))) float *output_allocated,
2923
for (int i = idx; i < m * n; i += bdimx) {
3024
int row = i / n;
3125
int col = i % n;
32-
float acc = 0.0f;
26+
float acc = 0;
3327
for (int j = 0; j < k; j++) {
34-
float a_val = raw_half_to_float(
35-
a_aligned[a_offsets + row * a_stride1 + j * a_stride2]);
36-
float b_val = raw_half_to_float(
37-
b_aligned[b_offsets + j * b_stride1 + col * b_stride2]);
38-
acc += a_val * b_val;
28+
acc += a_aligned[a_offsets + row * a_stride1 + j * a_stride2] *
29+
b_aligned[b_offsets + j * b_stride1 + col * b_stride2];
3930
}
4031
output_aligned[output_offsets + row * output_stride1 +
4132
col * output_stride2] += acc;
@@ -47,8 +38,8 @@ MatMul(__attribute__((address_space(3))) float *output_allocated,
4738
__attribute__((address_space(3))) float *allocated;
4839
__attribute__((address_space(3))) float *aligned;
4940
int64_t offsets;
50-
int64_t sizes1[2];
51-
int64_t stride1[2];
41+
int64_t sizes[2];
42+
int64_t strides[2];
5243
} r{output_allocated,
5344
output_aligned,
5445
output_offsets,

0 commit comments

Comments
 (0)