Skip to content

Commit cd9d504

Browse files
committed
add check for nvls fp8 support
1 parent 000cd5b commit cd9d504

13 files changed

Lines changed: 109 additions & 23 deletions

src/ext/collectives/allreduce/allreduce_nvls_block_pipeline.cu

Lines changed: 5 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -176,6 +176,7 @@ struct NvlsBlockPipelineAdapter {
176176

177177
void AllreduceNvlsBlockPipeline::initialize(std::shared_ptr<Communicator> comm) {
178178
nSwitchChannels_ = 8;
179+
fp8NvlsSupported_ = isFp8NvlsSupported();
179180
int nRanksPerIpcDomain = comm->bootstrap()->getNranksPerIpcDomain();
180181
// Per-peer channel allocation must hold up to 4 * nRanksPerIpcDomain entries (see kernel).
181182
int nBaseChannels = std::max(64, 4 * nRanksPerIpcDomain);
@@ -194,6 +195,10 @@ CommResult AllreduceNvlsBlockPipeline::allreduceKernelFunc(
194195
ReduceOp op, cudaStream_t stream, int nBlocks, int nThreadsPerBlock,
195196
[[maybe_unused]] const std::unordered_map<std::string, uintptr_t>& extras, DataType accumDtype) {
196197
auto ctx = std::static_pointer_cast<AlgorithmCtx>(ctx_void);
198+
if (isNativeFp8DataType(dtype) && !fp8NvlsSupported_) {
199+
WARN("FP8 NVLS allreduce requires device support for FP8 multimem reduction.");
200+
return CommResult::CommInvalidArgument;
201+
}
197202
AllreduceFunc allreduce = dispatch<NvlsBlockPipelineAdapter>(op, dtype, accumDtype);
198203
if (!allreduce) {
199204
WARN("Unsupported operation or data type for allreduce, dtype=%d", static_cast<int>(dtype));

src/ext/collectives/allreduce/allreduce_nvls_warp_pipeline.cu

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,7 @@ struct NvlsWarpPipelineAdapter {
140140

141141
void AllreduceNvlsWarpPipeline::initialize(std::shared_ptr<Communicator> comm) {
142142
nSwitchChannels_ = NUM_NVLS_CONNECTION;
143+
fp8NvlsSupported_ = isFp8NvlsSupported();
143144
int nRanksPerIpcDomain = comm->bootstrap()->getNranksPerIpcDomain();
144145
// Per-peer channel allocation must hold 2 * nBlocks entries; default nBlocks = 4 * nRanksPerIpcDomain.
145146
int nBaseChannels = std::max(64, 8 * nRanksPerIpcDomain);
@@ -158,6 +159,10 @@ CommResult AllreduceNvlsWarpPipeline::allreduceKernelFunc(
158159
ReduceOp op, cudaStream_t stream, int nBlocks, int nThreadsPerBlock,
159160
[[maybe_unused]] const std::unordered_map<std::string, uintptr_t>& extras, DataType accumDtype) {
160161
auto ctx = std::static_pointer_cast<AlgorithmCtx>(ctx_void);
162+
if (isNativeFp8DataType(dtype) && !fp8NvlsSupported_) {
163+
WARN("FP8 NVLS allreduce requires device support for FP8 multimem reduction.");
164+
return CommResult::CommInvalidArgument;
165+
}
161166
AllreduceFunc allreduce = dispatch<NvlsWarpPipelineAdapter>(op, dtype, accumDtype);
162167
if (!allreduce) {
163168
WARN("Unsupported operation or data type for allreduce, dtype=%d", static_cast<int>(dtype));

src/ext/collectives/allreduce/allreduce_nvls_zero_copy.cu

Lines changed: 3 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -97,6 +97,7 @@ void AllreduceNvls::initialize(std::shared_ptr<mscclpp::Communicator> comm) {
9797
cudaDeviceProp deviceProp;
9898
MSCCLPP_CUDATHROW(cudaGetDeviceProperties(&deviceProp, device));
9999
computeCapabilityMajor_ = deviceProp.major;
100+
fp8NvlsSupported_ = isFp8NvlsSupported();
100101
nSwitchChannels_ = 32;
101102
this->conns_ = setupConnections(comm);
102103
// setup semaphores
@@ -119,13 +120,10 @@ CommResult AllreduceNvls::allreduceKernelFunc(const std::shared_ptr<void> ctx_vo
119120
return CommResult::CommInvalidArgument;
120121
}
121122
auto ctx = std::static_pointer_cast<AlgorithmCtx>(ctx_void);
122-
#if defined(__FP8_TYPES_EXIST__)
123-
bool isFp8Dtype = dtype == mscclpp::DataType::FLOAT8_E4M3FN || dtype == mscclpp::DataType::FLOAT8_E5M2;
124-
if (isFp8Dtype && computeCapabilityMajor_ < 10) {
125-
WARN("FP8 NVLS allreduce requires compute capability 10.x or newer.");
123+
if (isNativeFp8DataType(dtype) && !fp8NvlsSupported_) {
124+
WARN("FP8 NVLS allreduce requires device support for FP8 multimem reduction.");
126125
return CommResult::CommInvalidArgument;
127126
}
128-
#endif
129127
AllreduceFunc allreduce = dispatch<NvlsAdapter>(op, dtype, accumDtype);
130128
if (!allreduce) {
131129
WARN("Unsupported operation or data type for allreduce, dtype=%d", static_cast<int>(dtype));

src/ext/collectives/allreduce/allreduce_packet.cu

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -197,12 +197,7 @@ inline std::pair<int, int> getDefaultBlockNumAndThreadNum(size_t inputSize, int
197197

198198
// FP8-specific tuning for 32KB-256KB range
199199
{
200-
bool isFp8 = dtype == DataType::FLOAT8_E4M3B15;
201-
#if defined(__FP8_TYPES_EXIST__)
202-
isFp8 = isFp8 || dtype == DataType::FLOAT8_E4M3FN || dtype == DataType::FLOAT8_E4M3FNUZ ||
203-
dtype == DataType::FLOAT8_E5M2 || dtype == DataType::FLOAT8_E5M2FNUZ;
204-
#endif
205-
if (isFp8) {
200+
if (isFp8DataType(dtype)) {
206201
if (inputSize < (64 << 10)) {
207202
nThreadsPerBlock = 64;
208203
} else if (inputSize >= (64 << 10) && inputSize <= (128 << 10)) {
Lines changed: 78 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -6,11 +6,89 @@
66
#include <algorithm>
77
#include <mscclpp/algorithm.hpp>
88
#include <mscclpp/core.hpp>
9+
#include <mscclpp/gpu_utils.hpp>
910
#include <mscclpp/memory_channel.hpp>
1011
#include <mscclpp/switch_channel.hpp>
1112

1213
namespace mscclpp {
1314
namespace collective {
15+
16+
namespace {
17+
18+
#if !defined(MSCCLPP_DEVICE_HIP)
19+
// Probe kernel: stores 1 into *supported when the device code being executed was compiled with
// __CUDA_ARCH__ >= 1000 and either __CUDA_ARCH_SPECIFIC__ or __CUDA_ARCH_FAMILY_SPECIFIC__ defined;
// stores 0 otherwise. The answer is decided at compile time per device-code target, so the value
// read back reflects which compiled variant actually ran on the current device.
// NOTE(review): presumably these macros identify arch-/family-specific builds that expose FP8
// multimem reduction — confirm against the CUDA documentation.
// @param supported Device-writable pointer receiving the 0/1 result.
__global__ void fp8NvlsSupportProbeKernel(int* supported) {
20+
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 1000 && \
21+
(defined(__CUDA_ARCH_SPECIFIC__) || defined(__CUDA_ARCH_FAMILY_SPECIFIC__))
22+
*supported = 1;
23+
#else
24+
*supported = 0;
25+
#endif
26+
}
27+
28+
// Runs fp8NvlsSupportProbeKernel once and reports the flag it wrote.
// @return true only if the kernel launched, the stream synchronized without error, and the
//         kernel wrote a non-zero value; false on a launch or synchronize error.
// NOTE(review): a failing cudaMemcpyAsync goes through MSCCLPP_CUDATHROW (presumably throws),
// whereas launch and synchronize failures are reported as "not supported" — confirm this
// asymmetry is intended.
bool detectFp8NvlsSupport() {
29+
// NOTE(review): presumably keeps the probe from interfering with an in-progress CUDA graph
// capture — confirm against AvoidCudaGraphCaptureGuard's definition.
AvoidCudaGraphCaptureGuard cgcGuard;
30+
// Device-side int that the probe kernel writes its 0/1 answer into.
auto supportedDevice = mscclpp::detail::gpuCallocUnique<int>();
31+
int supportedHost = 0;
32+
auto stream = gpuStreamPool()->getStream();
33+
34+
fp8NvlsSupportProbeKernel<<<1, 1, 0, stream>>>(supportedDevice.get());
35+
// A launch error is treated as "not supported" rather than as a hard failure.
cudaError_t err = cudaGetLastError();
36+
if (err != cudaSuccess) {
37+
return false;
38+
}
39+
40+
MSCCLPP_CUDATHROW(cudaMemcpyAsync(&supportedHost, supportedDevice.get(), sizeof(supportedHost),
41+
cudaMemcpyDeviceToHost, stream));
42+
// Wait for both the kernel and the copy so supportedHost is valid before it is read.
err = cudaStreamSynchronize(stream);
43+
if (err != cudaSuccess) {
44+
// Result deliberately discarded; presumably called to reset the runtime's last-error state.
(void)cudaGetLastError();
45+
return false;
46+
}
47+
return supportedHost != 0;
48+
}
49+
#endif
50+
51+
} // namespace
52+
53+
// Returns true when dtype is any of the five FP8 enum values: FLOAT8_E4M3FN, FLOAT8_E4M3FNUZ,
// FLOAT8_E5M2, FLOAT8_E5M2FNUZ or FLOAT8_E4M3B15; false for every other DataType.
// Unlike the check this commit removes from allreduce_packet.cu, the comparison is not guarded
// by __FP8_TYPES_EXIST__, so the result does not depend on that macro.
bool isFp8DataType(DataType dtype) {
54+
return dtype == DataType::FLOAT8_E4M3FN || dtype == DataType::FLOAT8_E4M3FNUZ ||
55+
dtype == DataType::FLOAT8_E5M2 || dtype == DataType::FLOAT8_E5M2FNUZ ||
56+
dtype == DataType::FLOAT8_E4M3B15;
57+
}
58+
59+
// Returns true when dtype is one of the FP8 encodings selected by the build's FP8 macros:
//   - E4M3: FLOAT8_E4M3FNUZ if __FP8_E4M3_IS_FNUZ__ is defined, otherwise FLOAT8_E4M3FN.
//   - E5M2: FLOAT8_E5M2FNUZ if __FP8_E5M2_IS_FNUZ__ is defined, otherwise FLOAT8_E5M2.
// Always returns false when __FP8_TYPES_EXIST__ is not defined, and never returns true for
// FLOAT8_E4M3B15.
bool isNativeFp8DataType(DataType dtype) {
60+
#if defined(__FP8_TYPES_EXIST__)
61+
// E4M3 flavour: exactly one of the FNUZ / FN variants counts as native per build.
#if defined(__FP8_E4M3_IS_FNUZ__)
62+
if (dtype == DataType::FLOAT8_E4M3FNUZ) {
63+
return true;
64+
}
65+
#else
66+
if (dtype == DataType::FLOAT8_E4M3FN) {
67+
return true;
68+
}
69+
#endif
70+
// E5M2 flavour: same selection rule as above.
#if defined(__FP8_E5M2_IS_FNUZ__)
71+
if (dtype == DataType::FLOAT8_E5M2FNUZ) {
72+
return true;
73+
}
74+
#else
75+
if (dtype == DataType::FLOAT8_E5M2) {
76+
return true;
77+
}
78+
#endif
79+
#endif
80+
return false;
81+
}
82+
83+
// Reports whether FP8 NVLS reduction is usable in this process.
// Builds with MSCCLPP_DEVICE_HIP defined always return false. Otherwise the result of
// detectFp8NvlsSupport() is computed on the first call and stored in a function-local static,
// so later calls return the cached value without launching the probe again.
// NOTE(review): the cached value reflects whichever device was current on the first call —
// confirm that is acceptable for processes that use more than one device.
bool isFp8NvlsSupported() {
84+
#if defined(MSCCLPP_DEVICE_HIP)
85+
return false;
86+
#else
87+
static const bool supported = detectFp8NvlsSupport();
88+
return supported;
89+
#endif
90+
}
91+
1492
std::vector<mscclpp::RegisteredMemory> setupRemoteMemories(std::shared_ptr<mscclpp::Communicator> comm, int rank,
1593
mscclpp::RegisteredMemory localMemory) {
1694
std::vector<mscclpp::RegisteredMemory> remoteMemories;

src/ext/collectives/include/allreduce/allreduce_nvls_block_pipeline.hpp

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -33,6 +33,7 @@ class AllreduceNvlsBlockPipeline : public AlgorithmBuilder {
3333
std::vector<BaseMemoryChannel> baseChannels_;
3434
std::vector<Connection> conns_;
3535
std::vector<std::shared_ptr<NvlsConnection>> nvlsConnections_;
36+
bool fp8NvlsSupported_{false};
3637
};
3738
} // namespace collective
3839
} // namespace mscclpp

src/ext/collectives/include/allreduce/allreduce_nvls_warp_pipeline.hpp

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -33,6 +33,7 @@ class AllreduceNvlsWarpPipeline : public AlgorithmBuilder {
3333
std::vector<BaseMemoryChannel> baseChannels_;
3434
std::vector<Connection> conns_;
3535
std::vector<std::shared_ptr<NvlsConnection>> nvlsConnections_;
36+
bool fp8NvlsSupported_{false};
3637
};
3738
} // namespace collective
3839
} // namespace mscclpp

src/ext/collectives/include/allreduce/allreduce_nvls_zero_copy.hpp

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -36,6 +36,7 @@ class AllreduceNvls : public AlgorithmBuilder {
3636
std::vector<std::shared_ptr<NvlsConnection>> nvlsConnections_;
3737
std::vector<std::shared_ptr<NvlsConnection>> nvlsOutConnections_;
3838
int computeCapabilityMajor_{0};
39+
bool fp8NvlsSupported_{false};
3940
};
4041

4142
} // namespace collective

src/ext/collectives/include/collective_utils.hpp

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -39,6 +39,10 @@ constexpr int MAX_IPC_DOMAIN_NRANKS = 72;
3939

4040
constexpr int SCRATCH_SIZE = 2 * 1024 * 1024 * 70; // Two 70 MiB buffers for double-buffered packet scratch space.
4141

42+
bool isFp8DataType(DataType dtype);
43+
bool isNativeFp8DataType(DataType dtype);
44+
bool isFp8NvlsSupported();
45+
4246
std::vector<RegisteredMemory> setupRemoteMemories(std::shared_ptr<Communicator> comm, int rank,
4347
RegisteredMemory localMemory);
4448

src/ext/nccl/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -13,6 +13,7 @@ target_include_directories(mscclpp_nccl PRIVATE
1313
include
1414
${PROJECT_SOURCE_DIR}/include
1515
${PROJECT_SOURCE_DIR}/src/core/include
16+
${PROJECT_SOURCE_DIR}/src/ext/collectives/include
1617
${GPU_INCLUDE_DIRS}
1718
)
1819
target_link_libraries(mscclpp_nccl PUBLIC mscclpp mscclpp_collectives)

0 commit comments

Comments
 (0)