Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions .azure-pipelines/templates/nccl-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,15 @@ steps:
mpirun -np 8 --bind-to numa --allow-run-as-root -x LD_PRELOAD=/root/mscclpp/build/lib/libmscclpp_nccl.so -x MSCCLPP_NCCL_SYMMETRIC_MEMORY=1 -x NCCL_DEBUG=WARN -x MSCCLPP_ENABLE_NCCL_FALLBACK=TRUE -x MSCCLPP_NCCL_LIB_PATH=/root/nccl/build/lib/libnccl.so -x MSCCLPP_FORCE_NCCL_FALLBACK_OPERATION="broadcast" /root/nccl-tests/build/broadcast_perf -b 1K -e 1G -f 2 -d half -G 20 -w 10 -n 20
mpirun -np 8 --bind-to numa --allow-run-as-root -x LD_PRELOAD=/root/mscclpp/build/lib/libmscclpp_nccl.so -x MSCCLPP_NCCL_SYMMETRIC_MEMORY=1 -x NCCL_DEBUG=WARN -x MSCCLPP_ENABLE_NCCL_FALLBACK=TRUE -x MSCCLPP_NCCL_LIB_PATH=/root/nccl/build/lib/libnccl.so -x MSCCLPP_FORCE_NCCL_FALLBACK_OPERATION="allreduce" /root/nccl-tests/build/broadcast_perf -b 1K -e 1G -f 2 -d half -G 20 -w 10 -n 20

- template: run-remote-task.yml
parameters:
name: PyBench
displayName: Run Collective Benchmarks
remoteScript: |
mpirun --allow-run-as-root -np 8 python3 -m mscclpp_benchmark.bench_collective --collective allreduce --dtype float8_e4m3b15 --accum-type float32 --autotune --symmetric-memory
mpirun --allow-run-as-root -np 8 python3 -m mscclpp_benchmark.bench_collective --collective allreduce --dtype float8_e4m3fn --accum-type float16 --autotune --symmetric-memory
mpirun --allow-run-as-root -np 8 python3 -m mscclpp_benchmark.bench_collective --collective allreduce --dtype float16 --symmetric-memory --autotune

- template: stop.yml
parameters:
subscription: ${{ parameters.subscription }}
Expand Down
9 changes: 9 additions & 0 deletions .azure-pipelines/templates/rccl-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,15 @@ steps:
mpirun -np 8 --bind-to numa --allow-run-as-root -x LD_PRELOAD=/root/mscclpp/build/lib/libmscclpp_nccl.so -x MSCCLPP_NCCL_SYMMETRIC_MEMORY=1 -x NCCL_DEBUG=WARN /root/rocm-systems/projects/rccl-tests/build/all_reduce_perf -b 1K -e 1G -f 2 -d half -G 20 -w 10 -n 20
mpirun -np 8 --bind-to numa --allow-run-as-root /root/rocm-systems/projects/rccl-tests/build/all_reduce_perf -b 1K -e 1G -f 2 -d half -G 20 -w 10 -n 20

- template: run-remote-task.yml
parameters:
name: PyBench
displayName: Run Collective Benchmarks
remoteScript: |
mpirun --allow-run-as-root -x GPU_MAX_HW_QUEUES=8 -np 8 python3 -m mscclpp_benchmark.bench_collective --collective allreduce --dtype float8_e4m3b15 --accum-type float32 --autotune
mpirun --allow-run-as-root -x GPU_MAX_HW_QUEUES=8 -np 8 python3 -m mscclpp_benchmark.bench_collective --collective allreduce --dtype float8_e4m3fnuz --accum-type float32 --autotune
mpirun --allow-run-as-root -x GPU_MAX_HW_QUEUES=8 -np 8 python3 -m mscclpp_benchmark.bench_collective --collective allgather --dtype float8_e4m3b15 --autotune --buffer-mode out-of-place

Comment thread
Binyang2014 marked this conversation as resolved.
- template: stop.yml
parameters:
subscription: ${{ parameters.subscription }}
Expand Down
37 changes: 29 additions & 8 deletions docs/quickstart.md
Original file line number Diff line number Diff line change
Expand Up @@ -110,12 +110,12 @@ $ CXX=/opt/rocm/bin/hipcc python -m pip install ".[rocm6]"
```

> **Note:** A platform extra (`cuda11`, `cuda12`, `cuda13`, or `rocm6`) is required to install CuPy.
> The CUDA extras install pre-built CuPy wheels. The `rocm6` extra installs CuPy from source,
> which requires ROCm and may take longer. Running `pip install .` without an extra will not install CuPy.
> The CUDA extras install pre-built CuPy wheels and CUDA Python bindings. The `rocm6` extra installs CuPy from source
> and HIP Python 6.x, which require ROCm and may take longer. Running `pip install .` without an extra will not install CuPy.

Optional extras can be installed by specifying them in brackets. Available extras:
- **`cuda11`**, **`cuda12`**, **`cuda13`**: Install a pre-built CuPy package for your CUDA version.
- **`rocm6`**: Install CuPy from source for AMD ROCm platforms.
- **`cuda11`**, **`cuda12`**, **`cuda13`**: Install a pre-built CuPy package and CUDA Python bindings for your CUDA version.
- **`rocm6`**: Install CuPy from source and HIP Python 6.x for AMD ROCm platforms.
- **`benchmark`**: Install benchmark dependencies (mpi4py, prettytable, netifaces, matplotlib).
- **`test`**: Install test dependencies (pytest, mpi4py, netifaces).

Expand Down Expand Up @@ -209,15 +209,37 @@ $ mpirun -np 16 -npernode 8 -hostfile hostfile ./bin/mp_unit_tests -ip_port 10.0

## Performance Benchmark

### Python Benchmark
### Python Benchmark and Tuning

[Install the MSCCL++ Python package](#install-from-source-python-module) and run our Python AllReduce benchmark as follows. It requires MPI on the system.
[Install the MSCCL++ Python package](#install-from-source-python-module) and run the Python collective benchmark as follows. It requires MPI on the system.

```bash
# Install with benchmark dependencies and the appropriate CUDA/ROCm extras.
# Replace `cuda12` with your platform: cuda11, cuda12, cuda13, or rocm6.
$ python3 -m pip install ".[cuda12,benchmark,test]"
$ mpirun -tag-output -np 8 python3 ./python/mscclpp_benchmark/allreduce_bench.py

```

To autotune launch parameters and save a tuned config:

```bash
$ PYTHONPATH=$PWD/python mpirun -np 8 --allow-run-as-root \
python3 -m mscclpp_benchmark.bench_collective \
--collective allreduce \
--dtype float16 \
--batch-sizes 1,2,4,8 \
--autotune \
--write-config /tmp/mscclpp_tuned_configs.json
```

Use the tuned config in a benchmark:

```bash
$ PYTHONPATH=$PWD/python mpirun -np 8 --allow-run-as-root \
python3 -m mscclpp_benchmark.bench_collective \
--collective allreduce \
--dtype float16 \
--config-path /tmp/mscclpp_tuned_configs.json
```

(nccl-benchmark)=
Expand Down Expand Up @@ -291,4 +313,3 @@ Version: 0.8.0.post1.dev0+gc632fee37.d20251007
mscclpp.version
{'version': '0.8.0.post1.dev0+gc632fee37.d20251007', 'git_commit': 'g50382c567'}
```

69 changes: 24 additions & 45 deletions include/mscclpp/gpu_data_types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ using __bfloat162 = __nv_bfloat162;

/// Software float8 with 4 exponent bits, 3 mantissa bits, exponent bias = 15.
/// Format (MSB first): [sign:1][exponent:4][mantissa:3]
/// No infinities, no NaN. Encode saturates to ±1.75 (0x7e/0xfe).
/// No infinities, no NaN. Encode saturates to ±1.875 (0x7f/0xff).
/// Adapted from the Triton compiler's fp8e4b15 format.
struct alignas(1) __fp8_e4m3b15 {
uint8_t __x;
Expand Down Expand Up @@ -103,7 +103,7 @@ struct alignas(1) __fp8_e4m3b15 {
/// then convert fp16 → float32.
static MSCCLPP_HOST_DEVICE_INLINE float toFloat(uint8_t bits) {
// Branch-free decode: fp8 → fp16 → fp32, no special-case handling.
// Encode saturates to ±1.75, so 0x7f/0xff are never produced.
// Every byte maps to a finite value; encode saturates at ±1.875, so 0x7f/0xff decode to ±1.875.
// Refer:
// https://github.com/triton-lang/triton/blob/cf34004b8a67d290a962da166f5aa2fc66751326/python/triton/language/extra/cuda/utils.py#L34
uint16_t h = (uint16_t)bits << 8; // place fp8 in upper byte of fp16
Expand Down Expand Up @@ -132,10 +132,9 @@ struct alignas(1) __fp8_e4m3b15 {
} cvt = {h_val};
uint16_t fp16_bits = cvt.u;

// Clamp abs to max encodable value: 1.75 → fp16 = 0x3F00.
// Matches Triton: encode saturates, 0x7f/0xff are never produced.
// Clamp abs to max encodable value: 1.875 → fp16 = 0x3F80 (largest byte 0x7f/0xff).
uint16_t abs_fp16 = fp16_bits & 0x7FFFu;
if (abs_fp16 > 0x3F00u) abs_fp16 = 0x3F00u;
if (abs_fp16 > 0x3F80u) abs_fp16 = 0x3F80u;

// Reconstruct with sign.
uint16_t sign16 = fp16_bits & 0x8000u;
Expand Down Expand Up @@ -852,27 +851,17 @@ MSCCLPP_DEVICE_INLINE f32x4 to<f32x4, f8_e5m2x4>(const f8_e5m2x4& v) {

/// f32x2 -> f8_e4m3x2.
/// HIP gfx942: float -> fp8 (via __builtin_amdgcn_cvt_pk_fp8_f32).
/// NVIDIA SM90+: float -> half -> fp8 (via __nv_cvt_halfraw2_to_fp8x2).
/// NVIDIA pre-SM90: float -> half -> fp8 (via __nv_cvt_halfraw_to_fp8, element-wise).
/// NVIDIA: float -> fp8 directly (via __nv_cvt_float2_to_fp8x2). On SM89+ this maps to a
/// single hardware round-to-nearest-even instruction; on older arch it falls back to a
/// software direct conversion.
template <>
MSCCLPP_DEVICE_INLINE f8_e4m3x2 to<f8_e4m3x2, f32x2>(const f32x2& v) {
#if defined(MSCCLPP_DEVICE_HIP) && defined(__gfx942__)
uint32_t packed = __builtin_amdgcn_cvt_pk_fp8_f32(v.data[0], v.data[1], 0, false);
return bit_cast<f8_e4m3x2>(static_cast<__hip_fp8x2_storage_t>(packed));
#elif defined(MSCCLPP_DEVICE_CUDA) && __CUDA_ARCH__ >= 900
__half2_raw h2;
h2.x = bit_cast<unsigned short>(__float2half_rn(v.data[0]));
h2.y = bit_cast<unsigned short>(__float2half_rn(v.data[1]));
__nv_fp8x2_storage_t fp8x2 = __nv_cvt_halfraw2_to_fp8x2(h2, __NV_SATFINITE, __NV_E4M3);
return bit_cast<f8_e4m3x2>(fp8x2);
#elif defined(MSCCLPP_DEVICE_CUDA)
__half_raw h0, h1;
h0.x = bit_cast<unsigned short>(__float2half_rn(v.data[0]));
h1.x = bit_cast<unsigned short>(__float2half_rn(v.data[1]));
f8_e4m3x2 result;
result.data[0] = bit_cast<__fp8_e4m3>(__nv_cvt_halfraw_to_fp8(h0, __NV_SATFINITE, __NV_E4M3));
result.data[1] = bit_cast<__fp8_e4m3>(__nv_cvt_halfraw_to_fp8(h1, __NV_SATFINITE, __NV_E4M3));
return result;
__nv_fp8x2_storage_t fp8x2 = __nv_cvt_float2_to_fp8x2(make_float2(v.data[0], v.data[1]), __NV_SATFINITE, __NV_E4M3);
return bit_cast<f8_e4m3x2>(fp8x2);
Comment on lines +863 to +864
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this work on CUDA 11? Are we going to deprecate CUDA 11?

#else
f8_e4m3x2 result;
result.data[0] = static_cast<__fp8_e4m3>(v.data[0]);
Expand Down Expand Up @@ -909,27 +898,17 @@ MSCCLPP_DEVICE_INLINE f8_e4m3x4 to<f8_e4m3x4, f32x4>(const f32x4& v) {

/// f32x2 -> f8_e5m2x2.
/// HIP gfx942: float -> bf8 (via __builtin_amdgcn_cvt_pk_bf8_f32).
/// NVIDIA SM90+: float -> half -> fp8 (via __nv_cvt_halfraw2_to_fp8x2 with __NV_E5M2).
/// NVIDIA pre-SM90: float -> half -> fp8 (via __nv_cvt_halfraw_to_fp8, element-wise).
/// NVIDIA: float -> fp8 directly (via __nv_cvt_float2_to_fp8x2 with __NV_E5M2). On SM89+ this
/// maps to a single hardware round-to-nearest-even instruction; on older arch it falls back to a
/// software direct conversion.
template <>
MSCCLPP_DEVICE_INLINE f8_e5m2x2 to<f8_e5m2x2, f32x2>(const f32x2& v) {
#if defined(MSCCLPP_DEVICE_HIP) && defined(__gfx942__)
uint32_t packed = __builtin_amdgcn_cvt_pk_bf8_f32(v.data[0], v.data[1], 0, false);
return bit_cast<f8_e5m2x2>(static_cast<__hip_fp8x2_storage_t>(packed));
#elif defined(MSCCLPP_DEVICE_CUDA) && __CUDA_ARCH__ >= 900
__half2_raw h2;
h2.x = bit_cast<unsigned short>(__float2half_rn(v.data[0]));
h2.y = bit_cast<unsigned short>(__float2half_rn(v.data[1]));
__nv_fp8x2_storage_t fp8x2 = __nv_cvt_halfraw2_to_fp8x2(h2, __NV_SATFINITE, __NV_E5M2);
return bit_cast<f8_e5m2x2>(fp8x2);
#elif defined(MSCCLPP_DEVICE_CUDA)
__half_raw h0, h1;
h0.x = bit_cast<unsigned short>(__float2half_rn(v.data[0]));
h1.x = bit_cast<unsigned short>(__float2half_rn(v.data[1]));
f8_e5m2x2 result;
result.data[0] = bit_cast<__fp8_e5m2>(__nv_cvt_halfraw_to_fp8(h0, __NV_SATFINITE, __NV_E5M2));
result.data[1] = bit_cast<__fp8_e5m2>(__nv_cvt_halfraw_to_fp8(h1, __NV_SATFINITE, __NV_E5M2));
return result;
__nv_fp8x2_storage_t fp8x2 = __nv_cvt_float2_to_fp8x2(make_float2(v.data[0], v.data[1]), __NV_SATFINITE, __NV_E5M2);
return bit_cast<f8_e5m2x2>(fp8x2);
#else
f8_e5m2x2 result;
result.data[0] = static_cast<__fp8_e5m2>(v.data[0]);
Expand Down Expand Up @@ -1103,11 +1082,11 @@ MSCCLPP_DEVICE_INLINE f8_e4m3b15x2 to<f8_e4m3b15x2, f16x2>(const f16x2& v) {
#if defined(MSCCLPP_DEVICE_CUDA)
uint32_t in0;
asm("mov.b32 %0, %1;" : "=r"(in0) : "r"(*reinterpret_cast<const uint32_t*>(&v)));
// Clamp abs to max encodable e4m3b15 (0x3F00 = 1.75 in fp16).
// Clamp abs to max encodable e4m3b15 (0x3F80 = 1.875 in fp16).
uint32_t lo = in0 & 0xFFFFu, hi = in0 >> 16;
uint32_t alo = lo & 0x7FFFu, ahi = hi & 0x7FFFu;
alo = alo < 0x3F00u ? alo : 0x3F00u;
ahi = ahi < 0x3F00u ? ahi : 0x3F00u;
alo = alo < 0x3F80u ? alo : 0x3F80u;
ahi = ahi < 0x3F80u ? ahi : 0x3F80u;
uint32_t a0 = alo | (ahi << 16);
a0 = a0 * 2u + 0x00800080u;
uint32_t b0 = a0 | (in0 & 0x80008000u);
Expand All @@ -1118,7 +1097,7 @@ MSCCLPP_DEVICE_INLINE f8_e4m3b15x2 to<f8_e4m3b15x2, f16x2>(const f16x2& v) {
uint32_t in0 = v.words[0];
uint32_t abs0 = in0 & 0x7fff7fffu;
uint32_t a0;
asm volatile("v_pk_min_u16 %0, %1, %2" : "=v"(a0) : "v"(abs0), "v"(0x3F003F00u));
asm volatile("v_pk_min_u16 %0, %1, %2" : "=v"(a0) : "v"(abs0), "v"(0x3F803F80u));
a0 = a0 * 2u + 0x00800080u;
uint32_t b0 = a0 | (in0 & 0x80008000u);
uint16_t packed = (uint16_t)(((b0 >> 8) & 0xFFu) | ((b0 >> 16) & 0xFF00u));
Expand All @@ -1141,8 +1120,8 @@ MSCCLPP_DEVICE_INLINE f8_e4m3b15x4 to<f8_e4m3b15x4, f16x4>(const f16x4& v) {
asm("mov.b32 %0, %1;" : "=r"(in1) : "r"(v.words[1]));
uint32_t abs0 = in0 & 0x7fff7fffu;
uint32_t abs1 = in1 & 0x7fff7fffu;
uint32_t a0 = __vminu2(abs0, 0x3F003F00u);
uint32_t a1 = __vminu2(abs1, 0x3F003F00u);
uint32_t a0 = __vminu2(abs0, 0x3F803F80u);
uint32_t a1 = __vminu2(abs1, 0x3F803F80u);
a0 = a0 * 2u + 0x00800080u;
a1 = a1 * 2u + 0x00800080u;
uint32_t b0, b1;
Expand All @@ -1155,8 +1134,8 @@ MSCCLPP_DEVICE_INLINE f8_e4m3b15x4 to<f8_e4m3b15x4, f16x4>(const f16x4& v) {
uint32_t in0 = v.words[0], in1 = v.words[1];
uint32_t abs0 = in0 & 0x7fff7fffu, abs1 = in1 & 0x7fff7fffu;
uint32_t a0, a1;
asm volatile("v_pk_min_u16 %0, %1, %2" : "=v"(a0) : "v"(abs0), "v"(0x3F003F00u));
asm volatile("v_pk_min_u16 %0, %1, %2" : "=v"(a1) : "v"(abs1), "v"(0x3F003F00u));
asm volatile("v_pk_min_u16 %0, %1, %2" : "=v"(a0) : "v"(abs0), "v"(0x3F803F80u));
asm volatile("v_pk_min_u16 %0, %1, %2" : "=v"(a1) : "v"(abs1), "v"(0x3F803F80u));
a0 = a0 * 2u + 0x00800080u;
a1 = a1 * 2u + 0x00800080u;
uint32_t b0 = a0 | (in0 & 0x80008000u);
Expand Down Expand Up @@ -1268,8 +1247,8 @@ MSCCLPP_DEVICE_INLINE f8_e4m3b15x4 to<f8_e4m3b15x4, f32x4>(const f32x4& v) {
return to<f8_e4m3b15x4, f16x4>(h);
#elif defined(MSCCLPP_DEVICE_HIP) && defined(__gfx942__)
f16x4 h;
h.words[0] = __builtin_bit_cast(uint32_t, __builtin_amdgcn_cvt_pkrtz(v.data[0], v.data[1]));
h.words[1] = __builtin_bit_cast(uint32_t, __builtin_amdgcn_cvt_pkrtz(v.data[2], v.data[3]));
h.words[0] = __builtin_bit_cast(uint32_t, __floats2half2_rn(v.data[0], v.data[1]));
h.words[1] = __builtin_bit_cast(uint32_t, __floats2half2_rn(v.data[2], v.data[3]));
return to<f8_e4m3b15x4, f16x4>(h);
#else
f8_e4m3b15x4 result;
Expand Down
20 changes: 16 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,22 @@ dependencies = [
]

[project.optional-dependencies]
cuda11 = ["cupy-cuda11x"]
cuda12 = ["cupy-cuda12x"]
cuda13 = ["cupy-cuda13x"]
rocm6 = ["cupy"]
cuda11 = [
"cupy-cuda11x",
"cuda-bindings>=11.8,<12",
]
cuda12 = [
"cupy-cuda12x",
"cuda-bindings>=12,<13",
]
cuda13 = [
"cupy-cuda13x",
"cuda-bindings>=13,<14",
]
rocm6 = [
"cupy",
"hip-python>=6,<7",
]
benchmark = [
"mpi4py",
"prettytable",
Expand Down
16 changes: 15 additions & 1 deletion python/mscclpp_benchmark/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,18 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

from .mscclpp_op import MscclppAllReduce1, MscclppAllReduce2, MscclppAllReduce3, MscclppAllReduce4, MscclppAllReduce5
__all__ = [
Comment thread
Binyang2014 marked this conversation as resolved.
"MscclppAllReduce1",
"MscclppAllReduce2",
"MscclppAllReduce3",
"MscclppAllReduce4",
"MscclppAllReduce5",
]


def __getattr__(name):
if name in __all__:
from . import mscclpp_op

return getattr(mscclpp_op, name)
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
Loading
Loading