Skip to content

Commit a17230e

Browse files
authored
[Feature] Support E8M0 related type conversion and vectorized cast (tile-ai#1731)
* [Feature] Support E8M0 related vectorized cast * fix * address comments
1 parent 5748841 commit a17230e

File tree

4 files changed

+134
-1
lines changed

4 files changed

+134
-1
lines changed

src/target/codegen_cuda.cc

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1183,6 +1183,12 @@ void CodeGenTileLangCUDA::VisitExpr_(const CastNode *op, std::ostream &os) {
11831183
os << sret;
11841184
};
11851185

1186+
// A list of casting functions that are supported by TileLang templates.
1187+
// To add a new type conversion, you should do the following things:
1188+
// 1. Add the new conversion function in tl_templates. (__tl_cvt_xx)
1189+
// 2. Add a new if statement like the one below.
1190+
// 3. In src/target/utils.cc, allow this vectorizable cast.
1191+
11861192
// Handle conversion from float16 to float32
11871193
if (from_ty.is_float16() && target_ty.is_float() && target_ty.bits() == 32) {
11881194
// Use __half22float2 for vectorized conversion (half2 -> float2)
@@ -1251,6 +1257,53 @@ void CodeGenTileLangCUDA::VisitExpr_(const CastNode *op, std::ostream &os) {
12511257
}
12521258
}
12531259

1260+
// Handle conversion from float8 (E8M0) to bfloat16
1261+
if (from_ty.is_float8_e8m0fnu() && target_ty.is_bfloat16()) {
1262+
// Use __tl_cvt_e8m0x2_to_bfloat162 for vectorized conversion (fp8_e8m0x2 ->
1263+
// bfloat162)
1264+
if (lanes == 2 || lanes == 4 || lanes == 8) {
1265+
PrintVectorizedCast("__tl_cvt_e8m0x2_to_bfloat162",
1266+
"__nv_fp8x2_storage_t", "__nv_bfloat162", "", true,
1267+
false);
1268+
return;
1269+
}
1270+
}
1271+
1272+
// Handle conversion from bfloat16 to float8 (E8M0)
1273+
if (from_ty.is_bfloat16() && target_ty.is_float8_e8m0fnu()) {
1274+
// Use __tl_cvt_bfloat162_to_e8m0x2 for vectorized conversion (bfloat162 ->
1275+
// fp8_e8m0x2)
1276+
if (lanes == 2 || lanes == 4 || lanes == 8) {
1277+
PrintVectorizedCast("__tl_cvt_bfloat162_to_e8m0x2", "__nv_bfloat162",
1278+
"__nv_fp8x2_storage_t", "", false, true);
1279+
return;
1280+
}
1281+
}
1282+
1283+
// Handle conversion from float to float8 (E8M0)
1284+
if (from_ty.is_float() && from_ty.bits() == 32 &&
1285+
target_ty.is_float8_e8m0fnu()) {
1286+
// Use __tl_cvt_float2_to_e8m0x2 for vectorized conversion (float2 ->
1287+
// fp8_e8m0x2)
1288+
if (lanes == 2 || lanes == 4 || lanes == 8) {
1289+
PrintVectorizedCast("__tl_cvt_float2_to_e8m0x2", "float2",
1290+
"__nv_fp8x2_storage_t", "", false, true);
1291+
return;
1292+
}
1293+
}
1294+
1295+
// Handle conversion from double to float8 (E8M0)
1296+
if (from_ty.is_float() && from_ty.bits() == 64 &&
1297+
target_ty.is_float8_e8m0fnu()) {
1298+
// Use __tl_cvt_double2_to_e8m0x2 for vectorized conversion (double2 ->
1299+
// fp8_e8m0x2)
1300+
if (lanes == 2 || lanes == 4 || lanes == 8) {
1301+
PrintVectorizedCast("__tl_cvt_double2_to_e8m0x2", "double2",
1302+
"__nv_fp8x2_storage_t", "", false, true);
1303+
return;
1304+
}
1305+
}
1306+
12541307
// Handle conversion from float16 to float4 (E2M1)
12551308
if (from_ty.is_float16() && target_ty.is_float4_e2m1fn()) {
12561309
// Use __tl_cvt_half2_to_fp4x2 for vectorized conversion (half2 -> fp4x2)

src/target/utils.cc

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,9 @@ int TargetGetWarpSize(Target target) {
153153
}
154154

155155
bool IsCudaVectorizableFP8(DataType dtype) {
156+
// NOTE: E8M0 is a special type of FP8 which is not handled here
157+
// We only handle FP8 types which can be represented with
158+
// __nv_fp8_interpretation_t here
156159
return dtype.is_float8_e4m3() || dtype.is_float8_e4m3fn() ||
157160
dtype.is_float8_e5m2();
158161
}
@@ -182,6 +185,18 @@ bool IsCudaVectorizableCast(DataType from_ty, DataType target_ty) {
182185
if (IsCudaVectorizableFP8(from_ty) && target_ty.is_float())
183186
return true;
184187

188+
// float8 (E8M0) -> bfloat16
189+
if (from_ty.is_float8_e8m0fnu() && target_ty.is_bfloat16())
190+
return true;
191+
192+
// bfloat16 -> float8 (E8M0)
193+
if (from_ty.is_bfloat16() && target_ty.is_float8_e8m0fnu())
194+
return true;
195+
196+
// float32/double -> float8 (E8M0)
197+
if (from_ty.is_float() && target_ty.is_float8_e8m0fnu())
198+
return true;
199+
185200
// float4_e2m1fn -> float32
186201
if (from_ty.is_float4_e2m1fn() && target_ty.is_float())
187202
return true;

src/tl_templates/cuda/cuda_fp8.h

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -312,3 +312,58 @@ __tl_cvt_fp8x2_to_float2(const __nv_fp8x2_storage_t x,
312312
result.y = (float)tmp.y;
313313
return result;
314314
}
315+
316+
// ============================================================================
317+
// FP8 E8M0 Related Conversions
318+
// ============================================================================
319+
#if TL_HAS_FP8_E8M0
320+
321+
// fp8_e8m0 -> bfloat16
322+
TL_DEVICE __nv_bfloat16
323+
__tl_cvt_e8m0_to_bfloat16(const __nv_fp8_storage_t src) {
324+
__nv_bfloat16_raw raw = __nv_cvt_e8m0_to_bf16raw(src);
325+
return *reinterpret_cast<const __nv_bfloat16 *>(&raw);
326+
}
327+
328+
// fp8_e8m0x2 -> bfloat16x2
329+
TL_DEVICE __nv_bfloat162
330+
__tl_cvt_e8m0x2_to_bfloat162(const __nv_fp8x2_storage_t src) {
331+
__nv_bfloat162_raw raw = __nv_cvt_e8m0x2_to_bf162raw(src);
332+
return *reinterpret_cast<const __nv_bfloat162 *>(&raw);
333+
}
334+
335+
// bfloat16 -> fp8_e8m0
336+
TL_DEVICE
337+
__nv_fp8_storage_t __tl_cvt_bfloat16_to_e8m0(const __nv_bfloat16 src) {
338+
__nv_bfloat16_raw raw = *reinterpret_cast<const __nv_bfloat16_raw *>(&src);
339+
return __nv_cvt_bfloat16raw_to_e8m0(raw, __NV_SATFINITE, cudaRoundNearest);
340+
}
341+
342+
// bfloat162 -> fp8_e8m0x2
343+
TL_DEVICE __nv_fp8x2_storage_t
344+
__tl_cvt_bfloat162_to_e8m0x2(const __nv_bfloat162 src) {
345+
__nv_bfloat162_raw raw = *reinterpret_cast<const __nv_bfloat162_raw *>(&src);
346+
return __nv_cvt_bfloat162raw_to_e8m0x2(raw, __NV_SATFINITE, cudaRoundNearest);
347+
}
348+
349+
// float -> fp8_e8m0
350+
TL_DEVICE __nv_fp8_storage_t __tl_cvt_float_to_e8m0(const float src) {
351+
return __nv_cvt_float_to_e8m0(src, __NV_SATFINITE, cudaRoundNearest);
352+
}
353+
354+
// float2 -> fp8_e8m0x2
355+
TL_DEVICE __nv_fp8x2_storage_t __tl_cvt_float2_to_e8m0x2(const float2 src) {
356+
return __nv_cvt_float2_to_e8m0x2(src, __NV_SATFINITE, cudaRoundNearest);
357+
}
358+
359+
// double -> fp8_e8m0
360+
TL_DEVICE __nv_fp8_storage_t __tl_cvt_double_to_e8m0(const double src) {
361+
return __nv_cvt_double_to_e8m0(src, __NV_SATFINITE, cudaRoundNearest);
362+
}
363+
364+
// double2 -> fp8_e8m0x2
365+
TL_DEVICE __nv_fp8x2_storage_t __tl_cvt_double2_to_e8m0x2(const double2 src) {
366+
return __nv_cvt_double2_to_e8m0x2(src, __NV_SATFINITE, cudaRoundNearest);
367+
}
368+
369+
#endif

testing/python/language/test_tilelang_language_vectorized_cast.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,9 +55,12 @@ def run_vectorized_cast(src_dtype: T.dtype, dst_dtype: T.dtype, check_str: str,
5555

5656
code = kernel.get_kernel_source()
5757
code_parallel = kernel_parallel.get_kernel_source()
58-
print(code)
5958
assert check_str in code and check_str in code_parallel, f"Cast {src_dtype} to {dst_dtype} with {lanes=} is not vectorized!"
6059

60+
# Requires torch >= 2.8
61+
if src_dtype == T.float8_e8m0fnu or dst_dtype == T.float8_e8m0fnu:
62+
return
63+
6164
if src_dtype == T.float4_e2m1fn or dst_dtype == T.float4_e2m1fn:
6265
return
6366

@@ -106,6 +109,13 @@ def test_vectorized_cast(src_dtype, dst_dtype, check_str, lanes):
106109
(T.float8_e4m3fn, T.float32, "__tl_cvt_fp8x2_to_float2", 4),
107110
(T.float8_e5m2, T.float32, "__tl_cvt_fp8x2_to_float2", 2),
108111
(T.float8_e5m2, T.float32, "__tl_cvt_fp8x2_to_float2", 4),
112+
# E8M0 <-> BFloat16
113+
(T.float8_e8m0fnu, T.bfloat16, "__tl_cvt_e8m0x2_to_bfloat162", 2),
114+
(T.bfloat16, T.float8_e8m0fnu, "__tl_cvt_bfloat162_to_e8m0x2", 2),
115+
# Float -> E8M0
116+
(T.float32, T.float8_e8m0fnu, "__tl_cvt_float2_to_e8m0x2", 2),
117+
# Double -> E8M0
118+
(T.float64, T.float8_e8m0fnu, "__tl_cvt_double2_to_e8m0x2", 2),
109119
],
110120
)
111121
def test_vectorized_cast_fp8(src_dtype, dst_dtype, check_str, lanes):

0 commit comments

Comments
 (0)