-
Notifications
You must be signed in to change notification settings - Fork 100
Add collective benchmark and correctness check #814
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
20 commits
Select commit
Hold shift + click to select a range
4474359
add benchmark
Binyang2014 cdb0383
update
Binyang2014 0605022
WIP
Binyang2014 b302796
WIP
Binyang2014 2fe6b1e
WIP
Binyang2014 f1a5a7d
update
Binyang2014 44dab3b
update correctness check
Binyang2014 dc37dd6
remove some code
Binyang2014 569acc3
fix issue
Binyang2014 ab567ef
add new test
Binyang2014 ad97f72
Merge branch 'main' into binyli/benchmark
Binyang2014 c8a49fa
update
Binyang2014 493e3b3
WIP
Binyang2014 ce03bae
update
Binyang2014 f830639
WIP
Binyang2014 fac9467
WIP
Binyang2014 a861ccc
WIP
Binyang2014 efb6865
WIP
Binyang2014 04c995f
WIP
Binyang2014 cbfe3c3
WIP
Binyang2014 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -71,7 +71,7 @@ using __bfloat162 = __nv_bfloat162; | |
|
|
||
| /// Software float8 with 4 exponent bits, 3 mantissa bits, exponent bias = 15. | ||
| /// Format (MSB first): [sign:1][exponent:4][mantissa:3] | ||
| /// No infinities, no NaN. Encode saturates to ±1.75 (0x7e/0xfe). | ||
| /// No infinities, no NaN. Encode saturates to ±1.875 (0x7f/0xff). | ||
| /// Adapted from the Triton compiler's fp8e4b15 format. | ||
| struct alignas(1) __fp8_e4m3b15 { | ||
| uint8_t __x; | ||
|
|
@@ -103,7 +103,7 @@ struct alignas(1) __fp8_e4m3b15 { | |
| /// then convert fp16 → float32. | ||
| static MSCCLPP_HOST_DEVICE_INLINE float toFloat(uint8_t bits) { | ||
| // Branch-free decode: fp8 → fp16 → fp32, no special-case handling. | ||
| // Encode saturates to ±1.75, so 0x7f/0xff are never produced. | ||
| // Every byte maps to a finite value; encode saturates at ±1.875, so 0x7f/0xff decode to ±1.875. | ||
| // Refer: | ||
| // https://github.com/triton-lang/triton/blob/cf34004b8a67d290a962da166f5aa2fc66751326/python/triton/language/extra/cuda/utils.py#L34 | ||
| uint16_t h = (uint16_t)bits << 8; // place fp8 in upper byte of fp16 | ||
|
|
@@ -132,10 +132,9 @@ struct alignas(1) __fp8_e4m3b15 { | |
| } cvt = {h_val}; | ||
| uint16_t fp16_bits = cvt.u; | ||
|
|
||
| // Clamp abs to max encodable value: 1.75 → fp16 = 0x3F00. | ||
| // Matches Triton: encode saturates, 0x7f/0xff are never produced. | ||
| // Clamp abs to max encodable value: 1.875 → fp16 = 0x3F80 (largest byte 0x7f/0xff). | ||
| uint16_t abs_fp16 = fp16_bits & 0x7FFFu; | ||
| if (abs_fp16 > 0x3F00u) abs_fp16 = 0x3F00u; | ||
| if (abs_fp16 > 0x3F80u) abs_fp16 = 0x3F80u; | ||
|
|
||
| // Reconstruct with sign. | ||
| uint16_t sign16 = fp16_bits & 0x8000u; | ||
|
|
@@ -852,27 +851,17 @@ MSCCLPP_DEVICE_INLINE f32x4 to<f32x4, f8_e5m2x4>(const f8_e5m2x4& v) { | |
|
|
||
| /// f32x2 -> f8_e4m3x2. | ||
| /// HIP gfx942: float -> fp8 (via __builtin_amdgcn_cvt_pk_fp8_f32). | ||
| /// NVIDIA SM90+: float -> half -> fp8 (via __nv_cvt_halfraw2_to_fp8x2). | ||
| /// NVIDIA pre-SM90: float -> half -> fp8 (via __nv_cvt_halfraw_to_fp8, element-wise). | ||
| /// NVIDIA: float -> fp8 directly (via __nv_cvt_float2_to_fp8x2). On SM89+ this maps to a | ||
| /// single hardware round-to-nearest-even instruction; on older architectures it falls back to a | ||
| /// software direct conversion. | ||
| template <> | ||
| MSCCLPP_DEVICE_INLINE f8_e4m3x2 to<f8_e4m3x2, f32x2>(const f32x2& v) { | ||
| #if defined(MSCCLPP_DEVICE_HIP) && defined(__gfx942__) | ||
| uint32_t packed = __builtin_amdgcn_cvt_pk_fp8_f32(v.data[0], v.data[1], 0, false); | ||
| return bit_cast<f8_e4m3x2>(static_cast<__hip_fp8x2_storage_t>(packed)); | ||
| #elif defined(MSCCLPP_DEVICE_CUDA) && __CUDA_ARCH__ >= 900 | ||
| __half2_raw h2; | ||
| h2.x = bit_cast<unsigned short>(__float2half_rn(v.data[0])); | ||
| h2.y = bit_cast<unsigned short>(__float2half_rn(v.data[1])); | ||
| __nv_fp8x2_storage_t fp8x2 = __nv_cvt_halfraw2_to_fp8x2(h2, __NV_SATFINITE, __NV_E4M3); | ||
| return bit_cast<f8_e4m3x2>(fp8x2); | ||
| #elif defined(MSCCLPP_DEVICE_CUDA) | ||
| __half_raw h0, h1; | ||
| h0.x = bit_cast<unsigned short>(__float2half_rn(v.data[0])); | ||
| h1.x = bit_cast<unsigned short>(__float2half_rn(v.data[1])); | ||
| f8_e4m3x2 result; | ||
| result.data[0] = bit_cast<__fp8_e4m3>(__nv_cvt_halfraw_to_fp8(h0, __NV_SATFINITE, __NV_E4M3)); | ||
| result.data[1] = bit_cast<__fp8_e4m3>(__nv_cvt_halfraw_to_fp8(h1, __NV_SATFINITE, __NV_E4M3)); | ||
| return result; | ||
| __nv_fp8x2_storage_t fp8x2 = __nv_cvt_float2_to_fp8x2(make_float2(v.data[0], v.data[1]), __NV_SATFINITE, __NV_E4M3); | ||
| return bit_cast<f8_e4m3x2>(fp8x2); | ||
|
Comment on lines
+863
to
+864
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Does this work on CUDA 11? Are we going to deprecate CUDA 11? |
||
| #else | ||
| f8_e4m3x2 result; | ||
| result.data[0] = static_cast<__fp8_e4m3>(v.data[0]); | ||
|
|
@@ -909,27 +898,17 @@ MSCCLPP_DEVICE_INLINE f8_e4m3x4 to<f8_e4m3x4, f32x4>(const f32x4& v) { | |
|
|
||
| /// f32x2 -> f8_e5m2x2. | ||
| /// HIP gfx942: float -> bf8 (via __builtin_amdgcn_cvt_pk_bf8_f32). | ||
| /// NVIDIA SM90+: float -> half -> fp8 (via __nv_cvt_halfraw2_to_fp8x2 with __NV_E5M2). | ||
| /// NVIDIA pre-SM90: float -> half -> fp8 (via __nv_cvt_halfraw_to_fp8, element-wise). | ||
| /// NVIDIA: float -> fp8 directly (via __nv_cvt_float2_to_fp8x2 with __NV_E5M2). On SM89+ this | ||
| /// maps to a single hardware round-to-nearest-even instruction; on older architectures it falls back to a | ||
| /// software direct conversion. | ||
| template <> | ||
| MSCCLPP_DEVICE_INLINE f8_e5m2x2 to<f8_e5m2x2, f32x2>(const f32x2& v) { | ||
| #if defined(MSCCLPP_DEVICE_HIP) && defined(__gfx942__) | ||
| uint32_t packed = __builtin_amdgcn_cvt_pk_bf8_f32(v.data[0], v.data[1], 0, false); | ||
| return bit_cast<f8_e5m2x2>(static_cast<__hip_fp8x2_storage_t>(packed)); | ||
| #elif defined(MSCCLPP_DEVICE_CUDA) && __CUDA_ARCH__ >= 900 | ||
| __half2_raw h2; | ||
| h2.x = bit_cast<unsigned short>(__float2half_rn(v.data[0])); | ||
| h2.y = bit_cast<unsigned short>(__float2half_rn(v.data[1])); | ||
| __nv_fp8x2_storage_t fp8x2 = __nv_cvt_halfraw2_to_fp8x2(h2, __NV_SATFINITE, __NV_E5M2); | ||
| return bit_cast<f8_e5m2x2>(fp8x2); | ||
| #elif defined(MSCCLPP_DEVICE_CUDA) | ||
| __half_raw h0, h1; | ||
| h0.x = bit_cast<unsigned short>(__float2half_rn(v.data[0])); | ||
| h1.x = bit_cast<unsigned short>(__float2half_rn(v.data[1])); | ||
| f8_e5m2x2 result; | ||
| result.data[0] = bit_cast<__fp8_e5m2>(__nv_cvt_halfraw_to_fp8(h0, __NV_SATFINITE, __NV_E5M2)); | ||
| result.data[1] = bit_cast<__fp8_e5m2>(__nv_cvt_halfraw_to_fp8(h1, __NV_SATFINITE, __NV_E5M2)); | ||
| return result; | ||
| __nv_fp8x2_storage_t fp8x2 = __nv_cvt_float2_to_fp8x2(make_float2(v.data[0], v.data[1]), __NV_SATFINITE, __NV_E5M2); | ||
| return bit_cast<f8_e5m2x2>(fp8x2); | ||
| #else | ||
| f8_e5m2x2 result; | ||
| result.data[0] = static_cast<__fp8_e5m2>(v.data[0]); | ||
|
|
@@ -1103,11 +1082,11 @@ MSCCLPP_DEVICE_INLINE f8_e4m3b15x2 to<f8_e4m3b15x2, f16x2>(const f16x2& v) { | |
| #if defined(MSCCLPP_DEVICE_CUDA) | ||
| uint32_t in0; | ||
| asm("mov.b32 %0, %1;" : "=r"(in0) : "r"(*reinterpret_cast<const uint32_t*>(&v))); | ||
| // Clamp abs to max encodable e4m3b15 (0x3F00 = 1.75 in fp16). | ||
| // Clamp abs to max encodable e4m3b15 (0x3F80 = 1.875 in fp16). | ||
| uint32_t lo = in0 & 0xFFFFu, hi = in0 >> 16; | ||
| uint32_t alo = lo & 0x7FFFu, ahi = hi & 0x7FFFu; | ||
| alo = alo < 0x3F00u ? alo : 0x3F00u; | ||
| ahi = ahi < 0x3F00u ? ahi : 0x3F00u; | ||
| alo = alo < 0x3F80u ? alo : 0x3F80u; | ||
| ahi = ahi < 0x3F80u ? ahi : 0x3F80u; | ||
| uint32_t a0 = alo | (ahi << 16); | ||
| a0 = a0 * 2u + 0x00800080u; | ||
| uint32_t b0 = a0 | (in0 & 0x80008000u); | ||
|
|
@@ -1118,7 +1097,7 @@ MSCCLPP_DEVICE_INLINE f8_e4m3b15x2 to<f8_e4m3b15x2, f16x2>(const f16x2& v) { | |
| uint32_t in0 = v.words[0]; | ||
| uint32_t abs0 = in0 & 0x7fff7fffu; | ||
| uint32_t a0; | ||
| asm volatile("v_pk_min_u16 %0, %1, %2" : "=v"(a0) : "v"(abs0), "v"(0x3F003F00u)); | ||
| asm volatile("v_pk_min_u16 %0, %1, %2" : "=v"(a0) : "v"(abs0), "v"(0x3F803F80u)); | ||
| a0 = a0 * 2u + 0x00800080u; | ||
| uint32_t b0 = a0 | (in0 & 0x80008000u); | ||
| uint16_t packed = (uint16_t)(((b0 >> 8) & 0xFFu) | ((b0 >> 16) & 0xFF00u)); | ||
|
|
@@ -1141,8 +1120,8 @@ MSCCLPP_DEVICE_INLINE f8_e4m3b15x4 to<f8_e4m3b15x4, f16x4>(const f16x4& v) { | |
| asm("mov.b32 %0, %1;" : "=r"(in1) : "r"(v.words[1])); | ||
| uint32_t abs0 = in0 & 0x7fff7fffu; | ||
| uint32_t abs1 = in1 & 0x7fff7fffu; | ||
| uint32_t a0 = __vminu2(abs0, 0x3F003F00u); | ||
| uint32_t a1 = __vminu2(abs1, 0x3F003F00u); | ||
| uint32_t a0 = __vminu2(abs0, 0x3F803F80u); | ||
| uint32_t a1 = __vminu2(abs1, 0x3F803F80u); | ||
| a0 = a0 * 2u + 0x00800080u; | ||
| a1 = a1 * 2u + 0x00800080u; | ||
| uint32_t b0, b1; | ||
|
|
@@ -1155,8 +1134,8 @@ MSCCLPP_DEVICE_INLINE f8_e4m3b15x4 to<f8_e4m3b15x4, f16x4>(const f16x4& v) { | |
| uint32_t in0 = v.words[0], in1 = v.words[1]; | ||
| uint32_t abs0 = in0 & 0x7fff7fffu, abs1 = in1 & 0x7fff7fffu; | ||
| uint32_t a0, a1; | ||
| asm volatile("v_pk_min_u16 %0, %1, %2" : "=v"(a0) : "v"(abs0), "v"(0x3F003F00u)); | ||
| asm volatile("v_pk_min_u16 %0, %1, %2" : "=v"(a1) : "v"(abs1), "v"(0x3F003F00u)); | ||
| asm volatile("v_pk_min_u16 %0, %1, %2" : "=v"(a0) : "v"(abs0), "v"(0x3F803F80u)); | ||
| asm volatile("v_pk_min_u16 %0, %1, %2" : "=v"(a1) : "v"(abs1), "v"(0x3F803F80u)); | ||
| a0 = a0 * 2u + 0x00800080u; | ||
| a1 = a1 * 2u + 0x00800080u; | ||
| uint32_t b0 = a0 | (in0 & 0x80008000u); | ||
|
|
@@ -1268,8 +1247,8 @@ MSCCLPP_DEVICE_INLINE f8_e4m3b15x4 to<f8_e4m3b15x4, f32x4>(const f32x4& v) { | |
| return to<f8_e4m3b15x4, f16x4>(h); | ||
| #elif defined(MSCCLPP_DEVICE_HIP) && defined(__gfx942__) | ||
| f16x4 h; | ||
| h.words[0] = __builtin_bit_cast(uint32_t, __builtin_amdgcn_cvt_pkrtz(v.data[0], v.data[1])); | ||
| h.words[1] = __builtin_bit_cast(uint32_t, __builtin_amdgcn_cvt_pkrtz(v.data[2], v.data[3])); | ||
| h.words[0] = __builtin_bit_cast(uint32_t, __floats2half2_rn(v.data[0], v.data[1])); | ||
| h.words[1] = __builtin_bit_cast(uint32_t, __floats2half2_rn(v.data[2], v.data[3])); | ||
| return to<f8_e4m3b15x4, f16x4>(h); | ||
| #else | ||
| f8_e4m3b15x4 result; | ||
|
|
||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,4 +1,18 @@ | ||
| # Copyright (c) Microsoft Corporation. | ||
| # Licensed under the MIT License. | ||
|
|
||
| from .mscclpp_op import MscclppAllReduce1, MscclppAllReduce2, MscclppAllReduce3, MscclppAllReduce4, MscclppAllReduce5 | ||
| __all__ = [ | ||
|
Binyang2014 marked this conversation as resolved.
|
||
| "MscclppAllReduce1", | ||
| "MscclppAllReduce2", | ||
| "MscclppAllReduce3", | ||
| "MscclppAllReduce4", | ||
| "MscclppAllReduce5", | ||
| ] | ||
|
|
||
|
|
||
| def __getattr__(name): | ||
| if name in __all__: | ||
| from . import mscclpp_op | ||
|
|
||
| return getattr(mscclpp_op, name) | ||
| raise AttributeError(f"module {__name__!r} has no attribute {name!r}") | ||
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.