diff --git a/paddle/fluid/distributed/collective/deep_ep/CMakeLists.txt b/paddle/fluid/distributed/collective/deep_ep/CMakeLists.txt
index 6d1a63b6c04d3..d02f291d3d650 100644
--- a/paddle/fluid/distributed/collective/deep_ep/CMakeLists.txt
+++ b/paddle/fluid/distributed/collective/deep_ep/CMakeLists.txt
@@ -7,8 +7,12 @@ if(WITH_NVSHMEM)
                  CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS}")

   set(DEEPEP_KERNEL_SRCS
-      kernels/intranode.cu kernels/runtime.cu kernels/internode.cu
-      kernels/internode_ll.cu kernels/internode_ll_two_stage.cu)
+      kernels/intranode.cu
+      kernels/runtime.cu
+      kernels/internode.cu
+      kernels/internode_ll.cu
+      kernels/internode_ll_two_stage.cu
+      kernels/m2n_ll_two_stage.cu)

   cc_library(
     deepep_kernels
     SRCS ${DEEPEP_KERNEL_SRCS}
diff --git a/paddle/fluid/distributed/collective/deep_ep/config.hpp b/paddle/fluid/distributed/collective/deep_ep/config.hpp
index b32821a12ad6f..737e0eaa83963 100644
--- a/paddle/fluid/distributed/collective/deep_ep/config.hpp
+++ b/paddle/fluid/distributed/collective/deep_ep/config.hpp
@@ -149,10 +149,14 @@ struct LowLatencyBuffer {
   void* dispatch_rdma_send_buffer = nullptr;
   void* dispatch_rdma_recv_data_buffer = nullptr;
   int* dispatch_rdma_recv_count_buffer = nullptr;
+  // Note(ZKK): this is only used in M2N!
+  int* dispatch_rdma_recv_complete_buffer = nullptr;

   void* combine_rdma_send_buffer = nullptr;
   void* combine_rdma_recv_data_buffer = nullptr;
   int* combine_rdma_recv_flag_buffer = nullptr;
+  // Note(ZKK): this is only used in M2N!
+  int* combine_rdma_recv_complete_buffer = nullptr;

   void* combine_rdma_send_buffer_data_start = nullptr;
   size_t num_bytes_per_combine_msg = 0;
@@ -244,11 +248,19 @@ struct LowLatencyLayout {
           advance<int*>(rdma_buffer,
                         send_buffer_bytes * 2 + recv_buffer_bytes * 2 +
                             signaling_buffer_bytes * i),
+          // Note(ZKK): dispatch_rdma_recv_complete_buffer is only used in
+          // M2N, so we symbolically assign it an offset of 0 here.
+          advance<int*>(rdma_buffer, 0),
+
           advance(rdma_buffer, send_buffer_bytes * i),
           advance(rdma_buffer, send_buffer_bytes * 2 + recv_buffer_bytes * i),
           advance<int*>(rdma_buffer,
                         send_buffer_bytes * 2 + recv_buffer_bytes * 2 +
                             signaling_buffer_bytes * i),
+          // Note(ZKK): combine_rdma_recv_complete_buffer is only used in
+          // M2N, so we symbolically assign it an offset of 0 here.
+          advance<int*>(rdma_buffer, 0),
+
           advance(rdma_buffer, send_buffer_bytes * i),
           num_bytes_per_combine_msg};
     }
@@ -318,6 +330,12 @@ struct LowLatencyTwoStageLayout {
                                  combine_recv_flag_buffer_bytes);
     total_bytes += signaling_buffer_bytes * 2;

+    // Symmetric "complete" signaling buffers
+    // Note(ZKK): this is only used in M2N!
+    size_t recv_complete_buffer_bytes =
+        2 * M2N_NUM_MAX_MICRO_BATCHES * num_ranks * sizeof(int);
+    total_bytes += recv_complete_buffer_bytes * 2;
+
     // Assign pointers
     for (int i = 0; i < 2; ++i) {
       buffers[i] = {
@@ -327,11 +345,21 @@ struct LowLatencyTwoStageLayout {
           advance<int*>(rdma_buffer,
                         send_buffer_bytes * 2 + recv_buffer_bytes * 2 +
                             signaling_buffer_bytes * i),
+          // dispatch_rdma_recv_complete_buffer
+          advance<int*>(rdma_buffer,
+                        send_buffer_bytes * 2 + recv_buffer_bytes * 2 +
+                            signaling_buffer_bytes * 2 +
+                            recv_complete_buffer_bytes * i),
           advance(rdma_buffer, send_buffer_bytes * i),
           advance(rdma_buffer, send_buffer_bytes * 2 + recv_buffer_bytes * i),
           advance<int*>(rdma_buffer,
                         send_buffer_bytes * 2 + recv_buffer_bytes * 2 +
                             signaling_buffer_bytes * i),
+          // combine_rdma_recv_complete_buffer
+          advance<int*>(rdma_buffer,
+                        send_buffer_bytes * 2 + recv_buffer_bytes * 2 +
+                            signaling_buffer_bytes * 2 +
+                            recv_complete_buffer_bytes * i),
           advance(rdma_buffer, send_buffer_bytes * i),
           num_bytes_per_combine_msg};
     }
diff --git a/paddle/fluid/distributed/collective/deep_ep/deep_ep.cpp b/paddle/fluid/distributed/collective/deep_ep/deep_ep.cpp
index ac82ab2f0feb1..8cf6231bc16bf 100644
--- a/paddle/fluid/distributed/collective/deep_ep/deep_ep.cpp
+++ b/paddle/fluid/distributed/collective/deep_ep/deep_ep.cpp
@@ -138,8 +138,12 @@ Buffer::Buffer(int rank,
   }

   // Create 32 MiB workspace
-  CUDA_CHECK(cudaMalloc(&workspace, NUM_WORKSPACE_BYTES));
-  CUDA_CHECK(cudaMemsetAsync(workspace, 0, NUM_WORKSPACE_BYTES, comm_stream));
+  // Note(ZKK): here we allocate more (2 * M2N_NUM_WORKSPACE slots) to
+  // support M2N! We will optimize this later.
+  CUDA_CHECK(
+      cudaMalloc(&workspace, 2 * M2N_NUM_WORKSPACE * NUM_WORKSPACE_BYTES));
+  CUDA_CHECK(cudaMemsetAsync(
+      workspace, 0, 2 * M2N_NUM_WORKSPACE * NUM_WORKSPACE_BYTES, comm_stream));

   // MoE counter
   CUDA_CHECK(
@@ -2307,6 +2311,431 @@ Buffer::low_latency_combine_two_stage(
   // Return values
   return {combined_x, event, recv_hook};
 }
+
+std::tuple<deep_ep::detail::Tensor,
+           std::optional<deep_ep::detail::Tensor>,
+           deep_ep::detail::Tensor,
+           deep_ep::detail::Tensor,
+           deep_ep::detail::Tensor,
+           deep_ep::detail::Tensor,
+           deep_ep::detail::Tensor,
+           deep_ep::detail::Tensor,
+           std::optional<EventHandle>,
+           std::optional<std::function<void()>>>
+Buffer::m2n_low_latency_dispatch_two_stage(
+    const deep_ep::detail::Tensor& x,
+    const deep_ep::detail::Tensor& topk_idx,
+    const deep_ep::detail::Tensor& topk_weights,
+    int num_max_dispatch_tokens_per_rank,
+    int num_experts,
+    int a_start_rank,
+    int a_num_ranks,
+    int e_start_rank,
+    int e_num_ranks,
+    bool use_fp8,
+    bool async,
+    bool return_recv_hook) {
+  EP_HOST_ASSERT(low_latency_mode);
+
+  // Tensor checks
+  EP_HOST_ASSERT(x.dim() == 2 && x.is_contiguous() &&
+                 x.scalar_type() == deep_ep::detail::kBFloat16);
+  EP_HOST_ASSERT(x.size(1) % sizeof(int4) == 0 && x.size(1) % 128 == 0);
+  EP_HOST_ASSERT(topk_idx.dim() == 2 && topk_idx.is_contiguous());
+  EP_HOST_ASSERT(x.size(0) == topk_idx.size(0) &&
+                 x.size(0) <= num_max_dispatch_tokens_per_rank);
+  EP_HOST_ASSERT(topk_idx.scalar_type() == deep_ep::detail::kInt64);
+  EP_HOST_ASSERT(num_experts % num_ranks == 0);
+
+  auto num_tokens = static_cast<int>(x.size(0)),
+       hidden = static_cast<int>(x.size(1));
+  auto num_scales = hidden / 128,
+       num_topk = static_cast<int>(topk_idx.size(1));
+  int num_local_experts = num_experts / num_ranks;
+
+  // Buffer control
+  LowLatencyTwoStageLayout layout(rdma_buffer_ptr,
+                                  num_max_dispatch_tokens_per_rank,
+                                  hidden,
+                                  num_ranks,
+                                  num_experts,
+                                  num_topk);
+  EP_HOST_ASSERT(layout.total_bytes <= num_rdma_bytes);
+  // fixed buffer: 0 for dispatch, 1 for combine
+  auto buffer = layout.buffers[0];
+  auto next_buffer = layout.buffers[1];
+  auto dispatch_workspace = reinterpret_cast<void*>(
+      reinterpret_cast<uint8_t*>(workspace) +
+      m2n_ll_dispatch_workspace_idx * NUM_WORKSPACE_BYTES);
+  m2n_ll_dispatch_workspace_idx =
+      (m2n_ll_dispatch_workspace_idx + 1) % M2N_NUM_WORKSPACE;
+  auto dispatch_rdma_recv_complete =
+      buffer.dispatch_rdma_recv_complete_buffer +
+      m2n_ll_dispatch_recv_complete_idx * num_ranks;
+  m2n_ll_dispatch_recv_complete_idx =
+      (m2n_ll_dispatch_recv_complete_idx + 1) % M2N_NUM_MAX_MICRO_BATCHES;
+
+  // Wait previous tasks to be
finished + // NOTES: the hook mode will always use the default stream + // auto compute_stream = calc_ctx->stream(); + // auto launch_stream = return_recv_hook ? compute_stream : comm_stream; + // EP_HOST_ASSERT(!(async && return_recv_hook)); + // if (!return_recv_hook) stream_wait(launch_stream, compute_stream); + + auto compute_stream = calc_ctx->stream(); + auto launch_stream = comm_stream; + if (rank >= a_start_rank && rank < a_start_rank + a_num_ranks) { + stream_wait(launch_stream, compute_stream); + } + + if (rank >= a_start_rank && rank < a_start_rank + a_num_ranks) { + stream_wait(compute_stream, launch_stream); + } + + auto return_x_dtype = phi::DataType::BFLOAT16; + if (use_fp8) { + return_x_dtype = phi::DataType::FLOAT8_E4M3FN; + } + + // Allocate packed tensors + auto packed_recv_x = ConvertPaddleTensorToDetailTensor( + paddle::experimental::empty({num_local_experts, + num_ranks * num_max_dispatch_tokens_per_rank, + hidden}, + return_x_dtype, + x.place())); + auto rdma_send_flags = ConvertPaddleTensorToDetailTensor( + paddle::experimental::empty({num_tokens, num_ranks / NUM_MAX_NVL_PEERS}, + phi::DataType::BOOL, + phi::GPUPlace(device_id))); + auto packed_recv_src_info = + ConvertPaddleTensorToDetailTensor(paddle::experimental::empty( + {num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank}, + phi::DataType::INT32, + phi::GPUPlace(device_id))); + auto packed_recv_layout_range = ConvertPaddleTensorToDetailTensor( + paddle::experimental::empty({num_local_experts, num_ranks}, + phi::DataType::INT64, + phi::GPUPlace(device_id))); + auto packed_recv_count = + ConvertPaddleTensorToDetailTensor(paddle::experimental::empty( + {num_local_experts}, phi::DataType::INT32, phi::GPUPlace(device_id))); + auto packed_rdma_recv_count = ConvertPaddleTensorToDetailTensor( + paddle::experimental::empty({num_ranks / NUM_MAX_NVL_PEERS}, + phi::DataType::INT32, + phi::GPUPlace(device_id))); + + const size_t num_bytes_per_msg = + sizeof(int4) + + (num_ranks / NUM_MAX_NVL_PEERS * (num_topk * 3 + 1) * sizeof(int) + + sizeof(int4) - 1) / + sizeof(int4) * sizeof(int4) + + (use_fp8 ? 
(hidden + num_scales * sizeof(float)) + : (hidden * sizeof(nv_bfloat16))); + auto packed_rdma_recv_x = ConvertPaddleTensorToDetailTensor( + paddle::experimental::empty({num_ranks / NUM_MAX_NVL_PEERS, + num_max_dispatch_tokens_per_rank, + num_bytes_per_msg}, + phi::DataType::UINT8, + phi::GPUPlace(device_id))); + + // Allocate column-majored scales + auto packed_recv_x_scales = std::optional(); + float* packed_recv_x_scales_ptr = nullptr; + if (use_fp8) { + EP_HOST_ASSERT((num_ranks * num_max_dispatch_tokens_per_rank) % 4 == 0 && + "TMA requires the number of tokens to be multiple of 4"); + packed_recv_x_scales = + ConvertPaddleTensorToDetailTensor(paddle::experimental::empty( + {num_local_experts, + num_scales, + num_ranks * num_max_dispatch_tokens_per_rank}, + phi::DataType::FLOAT32, + phi::GPUPlace(device_id))); + packed_recv_x_scales = + ConvertPaddleTensorToDetailTensor(paddle::experimental::transpose( + ConvertDetailTensorToPaddleTensor(packed_recv_x_scales.value()), + std::vector{0, 2, 1})); + packed_recv_x_scales_ptr = packed_recv_x_scales.value().data_ptr(); + } + + // Kernel launch + auto next_clean_meta = next_buffer.clean_meta(); + auto launcher = [=](int phases) { + m2n_ll_two_stage::dispatch(packed_recv_x.data_ptr(), + packed_recv_x_scales_ptr, + packed_rdma_recv_x.data_ptr(), + packed_recv_src_info.data_ptr(), + packed_recv_layout_range.data_ptr(), + packed_recv_count.data_ptr(), + packed_rdma_recv_count.data_ptr(), + rdma_send_flags.data_ptr(), + buffer.dispatch_rdma_recv_data_buffer, + buffer.dispatch_rdma_recv_count_buffer, + dispatch_rdma_recv_complete, + buffer.dispatch_rdma_send_buffer, + buffer_ptrs_gpu, + x.data_ptr(), + topk_idx.data_ptr(), + topk_weights.data_ptr(), + next_clean_meta.first, + next_clean_meta.second, + num_tokens, + hidden, + num_max_dispatch_tokens_per_rank, + num_topk, + num_experts, + rank, + num_ranks, + a_start_rank, + a_num_ranks, + e_start_rank, + e_num_ranks, + use_fp8, + dispatch_workspace, + launch_stream, + phases); + }; + + // TODO(Zhenyu Li): supports async/return_recv_hook + launcher(return_recv_hook + ? LOW_LATENCY_SEND_PHASE + : (LOW_LATENCY_SEND_PHASE | LOW_LATENCY_RECV_PHASE)); + + // Wait streams + // std::optional event; + // if (async) { + // // NOTES: we must ensure the all tensors will not be deallocated before + // the + // // stream-wait happens, so in Python API, we must wrap all tensors into + // the + // // event handle. + // event = EventHandle(launch_stream); + // } else if (!return_recv_hook) { + // stream_wait(compute_stream, launch_stream); + // } + + std::optional event; + if (async) { + // NOTES: we must ensure the all tensors will not be deallocated before the + // stream-wait happens, so in Python API, we must wrap all tensors into the + // event handle. 
+ event = EventHandle(launch_stream); + } + // // stream_wait(launch_stream, compute_stream); + // if (rank >= a_start_rank && rank < a_start_rank + a_num_ranks) { + // stream_wait(compute_stream, launch_stream); + // } + + // Receiver callback + std::optional> recv_hook = std::nullopt; + if (return_recv_hook) + recv_hook = [=]() { + // stream_wait(launch_stream, compute_stream); + launcher(LOW_LATENCY_RECV_PHASE); + // stream_wait(compute_stream, launch_stream); + + // if (rank >= e_start_rank && rank < e_start_rank + e_num_ranks) { + // stream_wait(compute_stream, launch_stream); + // } + return EventHandle(launch_stream); + }; + + return {packed_recv_x, + packed_recv_x_scales, + packed_rdma_recv_x, + packed_recv_count, + packed_rdma_recv_count, + packed_recv_src_info, + packed_recv_layout_range, + rdma_send_flags, + event, + recv_hook}; +} + +std::tuple, + std::optional>> +Buffer::m2n_low_latency_combine_two_stage( + const deep_ep::detail::Tensor& x, + const deep_ep::detail::Tensor& rdma_recv_x, + const deep_ep::detail::Tensor& topk_idx, + const deep_ep::detail::Tensor& topk_weights, + const deep_ep::detail::Tensor& src_info, + const deep_ep::detail::Tensor& layout_range, + const deep_ep::detail::Tensor& rdma_send_flags, + const deep_ep::detail::Tensor& dispatch_rdma_recv_count, + int num_max_dispatch_tokens_per_rank, + int num_experts, + int a_start_rank, + int a_num_ranks, + int e_start_rank, + int e_num_ranks, + bool dispatch_use_fp8, + bool async, + bool return_recv_hook, + const std::optional& out) { + EP_HOST_ASSERT(low_latency_mode); + + // Tensor checks + EP_HOST_ASSERT(x.dim() == 3 && x.is_contiguous() && + x.scalar_type() == deep_ep::detail::kBFloat16); + EP_HOST_ASSERT(x.size(0) == num_experts / num_ranks); + EP_HOST_ASSERT(x.size(1) == num_ranks * num_max_dispatch_tokens_per_rank); + EP_HOST_ASSERT(x.size(2) % sizeof(int4) == 0 && x.size(2) % 128 == 0); + EP_HOST_ASSERT(topk_idx.dim() == 2 && topk_idx.is_contiguous()); + EP_HOST_ASSERT(topk_idx.size(0) == topk_weights.size(0) && + topk_idx.size(1) == topk_weights.size(1)); + EP_HOST_ASSERT(topk_idx.scalar_type() == deep_ep::detail::kInt64); + EP_HOST_ASSERT(topk_weights.dim() == 2 && topk_weights.is_contiguous()); + EP_HOST_ASSERT(topk_weights.size(0) <= num_max_dispatch_tokens_per_rank); + EP_HOST_ASSERT(topk_weights.scalar_type() == deep_ep::detail::kFloat32); + EP_HOST_ASSERT(src_info.dim() == 2 && src_info.is_contiguous()); + EP_HOST_ASSERT(src_info.scalar_type() == deep_ep::detail::kInt32 && + x.size(0) == src_info.size(0)); + EP_HOST_ASSERT(layout_range.dim() == 2 && layout_range.is_contiguous()); + EP_HOST_ASSERT(layout_range.scalar_type() == deep_ep::detail::kInt64); + EP_HOST_ASSERT(layout_range.size(0) == num_experts / num_ranks && + layout_range.size(1) == num_ranks); + auto hidden = static_cast(x.size(2)); + auto num_local_experts = num_experts / num_ranks, + num_topk = static_cast(topk_weights.size(1)); + auto num_combined_tokens = static_cast(topk_weights.size(0)); + + // Buffer control + LowLatencyTwoStageLayout layout(rdma_buffer_ptr, + num_max_dispatch_tokens_per_rank, + hidden, + num_ranks, + num_experts, + num_topk); + EP_HOST_ASSERT(layout.total_bytes <= num_rdma_bytes); + // fixed buffer, 0 for dispatch, 1 for combine + auto dispatch_buffer = layout.buffers[0]; + auto buffer = layout.buffers[1]; + auto next_buffer = layout.buffers[0]; + auto combine_workspace = reinterpret_cast( + reinterpret_cast(workspace) + + (M2N_NUM_WORKSPACE + m2n_ll_combine_workspace_idx) * NUM_WORKSPACE_BYTES); + 
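+  // The constructor's single allocation is carved into 2 * M2N_NUM_WORKSPACE
+  // slots of NUM_WORKSPACE_BYTES each: dispatch cycles through slots
+  // [0, M2N_NUM_WORKSPACE) and combine through
+  // [M2N_NUM_WORKSPACE, 2 * M2N_NUM_WORKSPACE), so overlapping micro-batches
+  // never share a scratch slot. A minimal sketch of this ring indexing
+  // (`slot_base` is a hypothetical helper, not part of this patch):
+  //
+  //   uint8_t* slot_base(void* ws, bool is_combine, int idx) {
+  //     const int slot = (is_combine ? M2N_NUM_WORKSPACE : 0) + idx;
+  //     return reinterpret_cast<uint8_t*>(ws) + slot * NUM_WORKSPACE_BYTES;
+  //   }
+  //
+  // combine_workspace above equals
+  // slot_base(workspace, /*is_combine=*/true, m2n_ll_combine_workspace_idx).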
m2n_ll_combine_workspace_idx = + (m2n_ll_combine_workspace_idx + 1) % M2N_NUM_WORKSPACE; + auto combine_rdma_recv_complete = + buffer.combine_rdma_recv_complete_buffer + + m2n_ll_combine_recv_complete_idx * num_ranks; + m2n_ll_combine_recv_complete_idx = + (m2n_ll_combine_recv_complete_idx + 1) % M2N_NUM_MAX_MICRO_BATCHES; + + // Wait previous tasks to be finished + // NOTES: the hook mode will always use the default stream + // auto compute_stream = calc_ctx->stream(); + // auto launch_stream = return_recv_hook ? compute_stream : comm_stream; + // EP_HOST_ASSERT(!(async && return_recv_hook)); + // if (!return_recv_hook) stream_wait(launch_stream, compute_stream); + + auto compute_stream = calc_ctx->stream(); + auto launch_stream = comm_stream; + if (rank >= e_start_rank && rank < e_start_rank + e_num_ranks) { + stream_wait(launch_stream, compute_stream); + } + + if (rank >= e_start_rank && rank < e_start_rank + e_num_ranks) { + stream_wait(compute_stream, launch_stream); + } + + // Allocate output tensor + deep_ep::detail::Tensor combined_x; + if (out.has_value()) { + EP_HOST_ASSERT(out->dim() == 2 && out->is_contiguous()); + EP_HOST_ASSERT(out->size(0) == num_combined_tokens && + out->size(1) == hidden); + EP_HOST_ASSERT(out->scalar_type() == x.scalar_type()); + combined_x = out.value(); + } else { + combined_x = ConvertPaddleTensorToDetailTensor(paddle::experimental::empty( + {num_combined_tokens, hidden}, x.dtype(), x.place())); + } + + // Kernel launch + auto next_clean_meta = next_buffer.clean_meta(); + auto launcher = [=](int phases) { + m2n_ll_two_stage::combine(combined_x.data_ptr(), + buffer.combine_rdma_recv_data_buffer, + buffer.combine_rdma_recv_flag_buffer, + buffer.combine_rdma_send_buffer, + combine_rdma_recv_complete, + rdma_recv_x.data_ptr(), + dispatch_rdma_recv_count.data_ptr(), + buffer_ptrs_gpu, + x.data_ptr(), + topk_idx.data_ptr(), + topk_weights.data_ptr(), + src_info.data_ptr(), + layout_range.data_ptr(), + rdma_send_flags.data_ptr(), + next_clean_meta.first, + next_clean_meta.second, + num_combined_tokens, + hidden, + num_max_dispatch_tokens_per_rank, + num_topk, + num_experts, + rank, + num_ranks, + a_start_rank, + a_num_ranks, + e_start_rank, + e_num_ranks, + combine_workspace, + launch_stream, + phases, + dispatch_use_fp8); + }; + // TODO(Zhenyu Li): supports async/return_recv_hook + launcher(return_recv_hook + ? LOW_LATENCY_SEND_PHASE + : (LOW_LATENCY_SEND_PHASE | LOW_LATENCY_RECV_PHASE)); + + // Wait streams + // std::optional event; + // if (async) { + // // NOTES: we must ensure the all tensors will not be deallocated before + // the + // // stream-wait happens, so in Python API, we must wrap all tensors into + // the + // // event handle. + // event = EventHandle(launch_stream); + // } else if (!return_recv_hook) { + // stream_wait(compute_stream, launch_stream); + // } + + std::optional event; + if (async) { + // NOTES: we must ensure the all tensors will not be deallocated before the + // stream-wait happens, so in Python API, we must wrap all tensors into the + // event handle. 
+ event = EventHandle(launch_stream); + } + // // stream_wait(launch_stream, compute_stream); + // if (rank >= e_start_rank && rank < e_start_rank + e_num_ranks) { + // stream_wait(compute_stream, launch_stream); + // } + // Receiver callback + std::optional> recv_hook = std::nullopt; + if (return_recv_hook) + recv_hook = [=]() { + // stream_wait(launch_stream, compute_stream); + launcher(LOW_LATENCY_RECV_PHASE); + // stream_wait(compute_stream, launch_stream); + // stream_wait(launch_stream, compute_stream); + // if (rank >= a_start_rank && rank < a_start_rank + a_num_ranks) { + // stream_wait(compute_stream, launch_stream); + // } + return EventHandle(launch_stream); + }; + + // Return values + return {combined_x, event, recv_hook}; +} + #endif // PADDLE_WITH_NVSHMEM std::tuple, + paddle::Tensor, + paddle::Tensor, + paddle::Tensor, + paddle::Tensor, + paddle::Tensor, + paddle::Tensor, + std::optional, + std::optional>> +Buffer::m2n_low_latency_dispatch_two_stage_api( + const paddle::Tensor& x, + const paddle::Tensor& topk_idx, + const paddle::Tensor& topk_weights, + int num_max_dispatch_tokens_per_rank, + int num_experts, + int a_start_rank, + int a_num_ranks, + int e_start_rank, + int e_num_ranks, + bool use_fp8, + bool async, + bool return_recv_hook) { +#ifdef PADDLE_WITH_NVSHMEM + const auto& x_ = ConvertPaddleTensorToDetailTensor(x); + const auto& topk_idx_ = ConvertPaddleTensorToDetailTensor(topk_idx); + const auto& topk_weights_ = ConvertPaddleTensorToDetailTensor(topk_weights); + + auto res = + m2n_low_latency_dispatch_two_stage(x_, + topk_idx_, + topk_weights_, + num_max_dispatch_tokens_per_rank, + num_experts, + a_start_rank, + a_num_ranks, + e_start_rank, + e_num_ranks, + use_fp8, + async, + return_recv_hook); + + auto packed_recv_x_ = ConvertDetailTensorToPaddleTensor(std::get<0>(res)); + + std::optional packed_recv_x_scales_; + if (std::get<1>(res).has_value()) { + packed_recv_x_scales_ = + ConvertDetailTensorToPaddleTensor(std::get<1>(res).value()); + } + auto packed_recv_rdma_x_ = + ConvertDetailTensorToPaddleTensor(std::get<2>(res)); + auto packed_recv_count_ = ConvertDetailTensorToPaddleTensor(std::get<3>(res)); + auto packed_rdma_recv_count_ = + ConvertDetailTensorToPaddleTensor(std::get<4>(res)); + auto packed_recv_src_info_ = + ConvertDetailTensorToPaddleTensor(std::get<5>(res)); + auto packed_recv_layout_range_ = + ConvertDetailTensorToPaddleTensor(std::get<6>(res)); + auto rdma_send_flags_ = ConvertDetailTensorToPaddleTensor(std::get<7>(res)); + + const auto& event = std::get<8>(res); + auto recv_hook = std::get<9>(res); + + return {packed_recv_x_, + packed_recv_x_scales_, + packed_recv_rdma_x_, + packed_recv_count_, + packed_rdma_recv_count_, + packed_recv_src_info_, + packed_recv_layout_range_, + rdma_send_flags_, + event, + recv_hook}; +#else + LOG(ERROR) << "NVSHMEM is not enabled. 
You can enable it by setting cmake " + "option WITH_NVSHMEM=ON."; + return {}; +#endif +} + +std::tuple, + std::optional>> +Buffer::m2n_low_latency_combine_two_stage_api( + const paddle::Tensor& x, + const paddle::Tensor& rdma_recv_x, + const paddle::Tensor& topk_idx, + const paddle::Tensor& topk_weights, + const paddle::Tensor& src_info, + const paddle::Tensor& layout_range, + const paddle::Tensor& rdma_send_flags, + const paddle::Tensor& dispatch_rdma_recv_count, + int num_max_dispatch_tokens_per_rank, + int num_experts, + int a_start_rank, + int a_num_ranks, + int e_start_rank, + int e_num_ranks, + bool dispatch_use_fp8, + bool async, + bool return_recv_hook, + const std::optional& out) { +#ifdef PADDLE_WITH_NVSHMEM + const auto& x_ = ConvertPaddleTensorToDetailTensor(x); + const auto& rdma_recv_x_ = ConvertPaddleTensorToDetailTensor(rdma_recv_x); + const auto& topk_idx_ = ConvertPaddleTensorToDetailTensor(topk_idx); + const auto& topk_weights_ = ConvertPaddleTensorToDetailTensor(topk_weights); + const auto& src_info_ = ConvertPaddleTensorToDetailTensor(src_info); + const auto& layout_range_ = ConvertPaddleTensorToDetailTensor(layout_range); + const auto& rdma_send_flags_ = + ConvertPaddleTensorToDetailTensor(rdma_send_flags); + const auto& dispatch_rdma_recv_count_ = + ConvertPaddleTensorToDetailTensor(dispatch_rdma_recv_count); + + std::optional out_ = std::nullopt; + if (out.has_value()) { + out_ = ConvertOptionalPaddleTensorToDetailTensor(out.value()); + } + + auto res = m2n_low_latency_combine_two_stage(x_, + rdma_recv_x_, + topk_idx_, + topk_weights_, + src_info_, + layout_range_, + rdma_send_flags_, + dispatch_rdma_recv_count_, + num_max_dispatch_tokens_per_rank, + num_experts, + a_start_rank, + a_num_ranks, + e_start_rank, + e_num_ranks, + dispatch_use_fp8, + async, + return_recv_hook, + out_); + + auto combined_x_ = ConvertDetailTensorToPaddleTensor(std::get<0>(res)); + const auto& event = std::get<1>(res); + auto recv_hook = std::get<2>(res); + + return {combined_x_, event, recv_hook}; +#else + LOG(ERROR) << "NVSHMEM is not enabled. 
You can enable it by setting cmake " + "option WITH_NVSHMEM=ON."; + return {}; +#endif +} + std::tuple, paddle::Tensor, diff --git a/paddle/fluid/distributed/collective/deep_ep/deep_ep.hpp b/paddle/fluid/distributed/collective/deep_ep/deep_ep.hpp index f0c3b69c3ffad..e6620a37d03c8 100644 --- a/paddle/fluid/distributed/collective/deep_ep/deep_ep.hpp +++ b/paddle/fluid/distributed/collective/deep_ep/deep_ep.hpp @@ -52,6 +52,10 @@ struct Buffer { // Low-latency mode buffer int low_latency_buffer_idx = 0; bool low_latency_mode = false; + int m2n_ll_dispatch_workspace_idx = 0; + int m2n_ll_combine_workspace_idx = 0; + int m2n_ll_dispatch_recv_complete_idx = 0; + int m2n_ll_combine_recv_complete_idx = 0; // NVLink Buffer int64_t num_nvl_bytes; @@ -327,6 +331,53 @@ struct Buffer { bool return_recv_hook, const std::optional& out); + std::tuple, + deep_ep::detail::Tensor, + deep_ep::detail::Tensor, + deep_ep::detail::Tensor, + deep_ep::detail::Tensor, + deep_ep::detail::Tensor, + deep_ep::detail::Tensor, + std::optional, + std::optional>> + m2n_low_latency_dispatch_two_stage( + const deep_ep::detail::Tensor& x, + const deep_ep::detail::Tensor& topk_idx, + const deep_ep::detail::Tensor& topk_weights, + int num_max_dispatch_tokens_per_rank, + int num_experts, + int a_start_rank, + int a_num_ranks, + int e_start_rank, + int e_num_ranks, + bool use_fp8, + bool async, + bool return_recv_hook); + + std::tuple, + std::optional>> + m2n_low_latency_combine_two_stage( + const deep_ep::detail::Tensor& x, + const deep_ep::detail::Tensor& rdma_recv_x, + const deep_ep::detail::Tensor& topk_idx, + const deep_ep::detail::Tensor& topk_weights, + const deep_ep::detail::Tensor& src_info, + const deep_ep::detail::Tensor& layout_range, + const deep_ep::detail::Tensor& rdma_send_flags, + const deep_ep::detail::Tensor& dispatch_rdma_recv_count, + int num_max_dispatch_tokens_per_rank, + int num_experts, + int a_start_rank, + int a_num_ranks, + int e_start_rank, + int e_num_ranks, + bool dispatch_use_fp8, + bool async, + bool return_recv_hook, + const std::optional& out); + #endif // PADDLE_WITH_NVSHMEM std::tuple& out); + std::tuple, + paddle::Tensor, + paddle::Tensor, + paddle::Tensor, + paddle::Tensor, + paddle::Tensor, + paddle::Tensor, + std::optional, + std::optional>> + m2n_low_latency_dispatch_two_stage_api(const paddle::Tensor& x, + const paddle::Tensor& topk_idx, + const paddle::Tensor& topk_weights, + int num_max_dispatch_tokens_per_rank, + int num_experts, + int a_start_rank, + int a_num_ranks, + int e_start_rank, + int e_num_ranks, + bool use_fp8, + bool async, + bool return_recv_hook); + + std::tuple, + std::optional>> + m2n_low_latency_combine_two_stage_api( + const paddle::Tensor& x, + const paddle::Tensor& rdma_recv_x, + const paddle::Tensor& topk_idx, + const paddle::Tensor& topk_weights, + const paddle::Tensor& src_info, + const paddle::Tensor& layout_range, + const paddle::Tensor& rdma_send_flags, + const paddle::Tensor& dispatch_rdma_recv_count, + int num_max_dispatch_tokens_per_rank, + int num_experts, + int a_start_rank, + int a_num_ranks, + int e_start_rank, + int e_num_ranks, + bool dispatch_use_fp8, + bool async, + bool return_recv_hook, + const std::optional& out); + std::tuple, paddle::Tensor, diff --git a/paddle/fluid/distributed/collective/deep_ep/kernels/api.cuh b/paddle/fluid/distributed/collective/deep_ep/kernels/api.cuh index 35fbba5a1c373..24f041f23c4dd 100644 --- a/paddle/fluid/distributed/collective/deep_ep/kernels/api.cuh +++ 
b/paddle/fluid/distributed/collective/deep_ep/kernels/api.cuh @@ -421,6 +421,76 @@ void clean_low_latency_buffer_two_stage(void** buffer_ptrs_gpu, } // namespace internode_ll_two_stage +namespace m2n_ll_two_stage { + +void dispatch(void* packed_recv_x, + float* packed_recv_x_scales, + void* packed_rdma_recv_x, + int* packed_recv_src_info, + int64_t* packed_recv_layout_range, + int* packed_recv_count, + int* packed_rdma_recv_count, + bool* rdma_send_flags, + void* rdma_recv_x, + int* rdma_recv_count, + int* rdma_recv_complete, + void* rdma_x, + void** nvl_recv_x, + const void* x, + const int64_t* topk_idx, + const float* topk_weights, + int* next_clean, + int num_next_clean_int, + int num_tokens, + int hidden, + int num_max_dispatch_tokens_per_rank, + int num_topk, + int num_experts, + int rank, + int num_ranks, + int a_start_rank, + int a_num_ranks, + int e_start_rank, + int e_num_ranks, + bool use_fp8, + void* workspace, + cudaStream_t stream, + int phases); + +void combine(void* combined_x, + void* rdma_recv_x, + int* rdma_recv_flag, + void* rdma_send_x, + int* rdma_recv_complete, + void* dispatch_rdma_recv_x, + const int* dispatch_rdma_recv_count, + void** nvl_buffer, + const void* x, + const int64_t* topk_idx, + const float* topk_weights, + const int* src_info, + const int64_t* layout_range, + const bool* rdma_send_flags, + int* next_clean, + int num_next_clean_int, + int num_combined_tokens, + int hidden, + int num_max_dispatch_tokens_per_rank, + int num_topk, + int num_experts, + int rank, + int num_ranks, + int a_start_rank, + int a_num_ranks, + int e_start_rank, + int e_num_ranks, + void* workspace, + cudaStream_t stream, + int phases, + bool dispatch_use_fp8); + +} // namespace m2n_ll_two_stage + #endif // PADDLE_WITH_NVSHMEM } // namespace deep_ep diff --git a/paddle/fluid/distributed/collective/deep_ep/kernels/configs.cuh b/paddle/fluid/distributed/collective/deep_ep/kernels/configs.cuh index 4d2036b55e53d..c2ffaefb9a3e9 100644 --- a/paddle/fluid/distributed/collective/deep_ep/kernels/configs.cuh +++ b/paddle/fluid/distributed/collective/deep_ep/kernels/configs.cuh @@ -24,6 +24,8 @@ #define NUM_WORKSPACE_BYTES (32 * 1024 * 1024) #define NUM_MAX_LOCAL_EXPERTS 1024 #define NUM_BUFFER_ALIGNMENT_BYTES 128 +#define M2N_NUM_MAX_MICRO_BATCHES 51 +#define M2N_NUM_WORKSPACE 3 #define FINISHED_SUM_TAG 1024 #define NUM_WAIT_NANOSECONDS 500 diff --git a/paddle/fluid/distributed/collective/deep_ep/kernels/launch.cuh b/paddle/fluid/distributed/collective/deep_ep/kernels/launch.cuh index ba9b8be9cdf37..4cae5d8f19f60 100644 --- a/paddle/fluid/distributed/collective/deep_ep/kernels/launch.cuh +++ b/paddle/fluid/distributed/collective/deep_ep/kernels/launch.cuh @@ -209,6 +209,9 @@ } else if (num_warp_groups == 4) { \ constexpr int kNumWarpGroups = 4; \ __VA_ARGS__ \ + } else if (num_warp_groups == 8) { \ + constexpr int kNumWarpGroups = 8; \ + __VA_ARGS__ \ } else { \ EP_HOST_ASSERT(false && "Unsupported num_warp_groups"); \ } diff --git a/paddle/fluid/distributed/collective/deep_ep/kernels/m2n_ll_two_stage.cu b/paddle/fluid/distributed/collective/deep_ep/kernels/m2n_ll_two_stage.cu new file mode 100644 index 0000000000000..63ebcd2cd239f --- /dev/null +++ b/paddle/fluid/distributed/collective/deep_ep/kernels/m2n_ll_two_stage.cu @@ -0,0 +1,1567 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// clang-format off +#include +#include +#include +#include +#include +// clang-format on +#include "paddle/fluid/distributed/collective/deep_ep/kernels/configs.cuh" +#include "paddle/fluid/distributed/collective/deep_ep/kernels/exception.cuh" +#include "paddle/fluid/distributed/collective/deep_ep/kernels/ibgda_device.cuh" +#include "paddle/fluid/distributed/collective/deep_ep/kernels/launch.cuh" + +namespace deep_ep { + +namespace m2n_ll_two_stage { + +constexpr bool M2N_LL_DEBUG = false; +constexpr bool M2N_LL_ACC_DEBUG = false; +constexpr bool M2N_LL_HANG_DEBUG = true; +constexpr int64_t M2N_NUM_HANG_CYCLES = 2000000000; // 345MHZ 5.8s; + +template +__global__ __launch_bounds__( + kNumWarpGroups* kNumWarpsPerGroup * 32, + 1) void dispatch_kernel(void* packed_recv_x, + float* packed_recv_x_scales, + void* packed_rdma_recv_x, + int* packed_recv_src_info, + int64_t* packed_recv_layout_range, + int* packed_recv_count, + int* packed_rdma_recv_count, + bool* rdma_send_flags, // kNumRdmaRanks + void* rdma_recv_x, + int* rdma_recv_count, + int* rdma_recv_complete, + void* rdma_x, + void** nvl_recv_x, // num_local_experts * dp_num * + // num_max_token_per_dp * + // hidden_size + const void* x, + const int64_t* topk_idx, + const float* topk_weights, + int* atomic_counter_per_expert, + int* atomic_counter_per_rdma, + int* atomic_finished_counter_per_rdma, + int* atomic_recv_tokens_per_rdma_expert, + int* atomic_nvl_sender_multi_sms, + int* atomic_counter_per_qp, + int num_tokens, + int num_max_dispatch_tokens_per_rank, + int rank, + int a_start_rank, + int a_num_ranks, + int e_start_rank, + int e_num_ranks, + int phases) { + constexpr int UNROLL_FACTOR = kHidden / 1024; + constexpr int kNumRanks = kNumRdmaRanks * NUM_MAX_NVL_PEERS; + constexpr int kNumLocalExperts = kNumExperts / kNumRanks; + constexpr int kNumRdmaExperts = kNumLocalExperts * NUM_MAX_NVL_PEERS; + + const auto sm_id = static_cast(blockIdx.x); + const auto num_sms = static_cast(gridDim.x); + const auto num_threads = static_cast(blockDim.x), + num_warps = num_threads / 32; + const auto thread_id = static_cast(threadIdx.x), + warp_id = thread_id / 32, lane_id = get_lane_id(); + const auto warp_group_id = warp_id / kNumWarpsPerGroup; + const auto sub_warp_id = warp_id % kNumWarpsPerGroup; + const auto responsible_expert_idx = sm_id * kNumWarpGroups + warp_group_id; + int a_start_rdma_rank = a_start_rank / NUM_MAX_NVL_PEERS; + int a_num_rdma_ranks = a_num_ranks / NUM_MAX_NVL_PEERS; + int e_start_rdma_rank = e_start_rank / NUM_MAX_NVL_PEERS; + int e_num_rdma_ranks = e_num_ranks / NUM_MAX_NVL_PEERS; + + const auto rdma_rank = rank / NUM_MAX_NVL_PEERS, + nvl_rank = rank % NUM_MAX_NVL_PEERS; + const int qp_id = sm_id % kNumQPs; + // check + if (sm_id == 0 && thread_id == 0) { + EP_DEVICE_ASSERT(ibgda_get_state()->num_rc_per_pe >= kNumQPs); + } + + // FP8 staffs + constexpr int kNumPerChannels = 128; + constexpr float kFP8Margin = 1e-4, kFP8Amax = 448, + kFP8AmaxInv = 1.0f / 448.0f; + constexpr int kNumScales = kHidden / kNumPerChannels; + const size_t hidden_bytes = + kHidden * (kUseFP8 ? 
sizeof(__nv_fp8_storage_t) : sizeof(nv_bfloat16)); + const size_t hidden_int4 = hidden_bytes / sizeof(int4); + + // index_source, hidden, (scale), nvl_valid_num, nvl_rank0, dst_idx0, + // topk_weight0, + // ..., nvl_rank8, dst_idx8, topk_weight8, ... + using vec_t = typename std::conditional::type; + const size_t num_bytes_per_msg = + sizeof(int4) + + (kNumRdmaRanks * (kTopk * 3 + 1) * sizeof(int) + sizeof(int4) - 1) / + sizeof(int4) * sizeof(int4) + + (kUseFP8 ? (kHidden + kNumScales * sizeof(float)) + : (kHidden * sizeof(nv_bfloat16))); + // rdma_index_source, hidden, (scale) + const size_t num_bytes_per_msg_rdma_revecier_and_nvl_sender = + sizeof(int4) + (kUseFP8 ? (kHidden + kNumScales * sizeof(float)) + : (kHidden * sizeof(nv_bfloat16))); + const size_t NVL_BUFFER_X_BYTES = + kNumLocalExperts * kNumRanks * num_max_dispatch_tokens_per_rank * + num_bytes_per_msg_rdma_revecier_and_nvl_sender; + const size_t num_bytes_per_msg_rdma_to_nvl = + kUseFP8 ? (kHidden + kNumScales * sizeof(float)) + : (kHidden * sizeof(nv_bfloat16)); + const size_t num_int4_per_msg = num_bytes_per_msg / sizeof(int4); + const size_t num_int4_per_msg_rdma_revecier_and_nvl_sender = + num_bytes_per_msg_rdma_revecier_and_nvl_sender / sizeof(int4); + const size_t num_int4_per_msg_rdma_to_nvl = + num_bytes_per_msg_rdma_to_nvl / sizeof(int4); + EP_DEVICE_ASSERT(num_bytes_per_msg % sizeof(int4) == 0); + EP_DEVICE_ASSERT( + num_bytes_per_msg_rdma_revecier_and_nvl_sender % sizeof(int4) == 0); + EP_DEVICE_ASSERT(num_bytes_per_msg_rdma_to_nvl % sizeof(int4) == 0); + + if ((phases & LOW_LATENCY_SEND_PHASE) == 0) goto LOW_LATENCY_DISPATCH_RECV; + + /* RDMA Sender */ + { + constexpr int kNumElemsPerRead = sizeof(int4) / sizeof(nv_bfloat16); + EP_DEVICE_ASSERT(kHidden % kNumElemsPerRead == 0); + EP_STATIC_ASSERT(kNumElemsPerRead * 32 % kNumPerChannels == 0, + "Invalid vectorization"); + const size_t hidden_bf16_int4 = kHidden / kNumElemsPerRead; + + for (int token_idx = sm_id; token_idx < num_tokens; token_idx += num_sms) { + const auto x_int4 = + reinterpret_cast(x) + token_idx * hidden_bf16_int4; + bool* rdma_send_flags_now = rdma_send_flags + token_idx * kNumRdmaRanks; + +// init rdma_send_flags +#pragma unroll + for (int flag_i = thread_id; flag_i < kNumRdmaRanks; + flag_i += num_threads) { + rdma_send_flags_now[flag_i] = false; + } + const auto rdma_x_src_idx = reinterpret_cast( + reinterpret_cast(rdma_x) + token_idx * num_bytes_per_msg); + const auto rdma_x_vec = reinterpret_cast( + reinterpret_cast(rdma_x_src_idx) + sizeof(int4)); + const auto rdma_x_scales = reinterpret_cast( + reinterpret_cast(rdma_x_vec) + hidden_bytes); + + const auto nvl_rank_meta = + reinterpret_cast(rdma_x_scales + (kUseFP8 ? kNumScales : 0)); + + thread_id == 0 ? 
(*rdma_x_src_idx = token_idx) : 0; + +#pragma unroll + for (int i = thread_id; i < hidden_bf16_int4; i += num_threads) { + // Read + auto int4_value = __ldg(x_int4 + i); + + if (kUseFP8) { + // Calculate local amax + auto bf16_values = reinterpret_cast(&int4_value); + float fp32_values[kNumElemsPerRead]; + float amax = kFP8Margin, scale, scale_inv; +#pragma unroll + for (int j = 0; j < kNumElemsPerRead; ++j) { + fp32_values[j] = static_cast(bf16_values[j]); + amax = fmaxf(amax, fabsf(fp32_values[j])); + } + + // Reduce amax and scale + EP_STATIC_ASSERT(kNumElemsPerRead * 32 / kNumPerChannels == 2, + "Invalid vectorization"); + amax = half_warp_reduce_max(amax), scale = kFP8Amax / amax, + scale_inv = amax * kFP8AmaxInv; + if (lane_id == 0 || lane_id == 16) + rdma_x_scales[i * kNumElemsPerRead / 128] = scale_inv; + + // Cast into send buffer + vec_t int2_value; + auto fp8x2_values = + reinterpret_cast<__nv_fp8x2_storage_t*>(&int2_value); +#pragma unroll + for (int j = 0; j < kNumElemsPerRead; j += 2) { + float2 fp32x2 = {fp32_values[j] * scale, + fp32_values[j + 1] * scale}; + fp8x2_values[j / 2] = + __nv_cvt_float2_to_fp8x2(fp32x2, __NV_SATFINITE, __NV_E4M3); + } + rdma_x_vec[i] = int2_value; + } else { + // Reinterpret-cast is for C++14 compatibility + rdma_x_vec[i] = *reinterpret_cast(&int4_value); + } + } + __syncthreads(); + + // Only need issue to MoE machine! + if (warp_id < e_num_rdma_ranks) { + const int dst_rdma_rank = warp_id + e_start_rdma_rank; + const int dst_rdma_expert_start = dst_rdma_rank * kNumRdmaExperts; + const int dst_rdma_expert_end = (dst_rdma_rank + 1) * kNumRdmaExperts; + + const int64_t* topk_idx_now = topk_idx + token_idx * kTopk; + const float* topk_weights_now = topk_weights + token_idx * kTopk; + + const auto nvl_rank_nums = + nvl_rank_meta + dst_rdma_rank * (kTopk * 3 + 1); + const auto nvl_rank_meta_now = nvl_rank_nums + 1; + + int dst_nvl_count = 0; + for (int topk_i = 0; topk_i < kTopk; ++topk_i) { + const int64_t expert_idx = topk_idx_now[topk_i]; + const float topk_weight = topk_weights_now[topk_i]; + if (expert_idx >= dst_rdma_expert_start && + expert_idx < dst_rdma_expert_end) { + if (lane_id == 0) { + nvl_rank_meta_now[dst_nvl_count * 3] = + expert_idx % kNumRdmaExperts; // dst_expert in dst_rdma_rank + const int dst_index = + atomicAdd(&atomic_counter_per_expert[expert_idx], 1); + nvl_rank_meta_now[dst_nvl_count * 3 + 1] = + dst_index; // dst_index + reinterpret_cast( + nvl_rank_meta_now)[dst_nvl_count * 3 + 2] = topk_weight; + } + dst_nvl_count += 1; + } + } + lane_id == 0 ? (nvl_rank_nums[0] = dst_nvl_count) : 0; + __syncwarp(); + + // dst_nvl_count > 0 means should issue message to dst_rdma_rank! + if (dst_nvl_count > 0) { + lane_id == 0 ? (rdma_send_flags_now[dst_rdma_rank] = true) : 0; + int slot_idx = + lane_id == 0 + ? atomicAdd(&atomic_counter_per_rdma[dst_rdma_rank], 1) + : 0; + slot_idx = __shfl_sync(0xffffffff, slot_idx, 0); // broadcast + const auto src_ptr = reinterpret_cast(rdma_x_src_idx); + const auto dst_ptr = + reinterpret_cast(rdma_recv_x) + + (rdma_rank * num_max_dispatch_tokens_per_rank + slot_idx) * + num_bytes_per_msg; + + // must run in RDMA! + if constexpr (kNumQPs > 1) { + nvshmemi_ibgda_put_nbi_warp( + dst_ptr, + src_ptr, + num_bytes_per_msg, + dst_rdma_rank * NUM_MAX_NVL_PEERS + nvl_rank, + qp_id, + lane_id, + 0); + } else { + nvshmemi_ibgda_put_nbi_warp( + dst_ptr, + src_ptr, + num_bytes_per_msg, + dst_rdma_rank * NUM_MAX_NVL_PEERS + nvl_rank, + qp_id, + lane_id, + slot_idx); + } + __syncwarp(); + lane_id == 0 + ? 
(atomic_add_release_global( + atomic_finished_counter_per_rdma + dst_rdma_rank, 1)) + : 0; + } + } + } + } + if (sm_id == num_sms - 1) { + for (int i = thread_id; i < kNumLocalExperts; i += num_threads) { + packed_recv_count[i] = 0; + } + } + cg::this_grid().sync(); + + // Issue count sends + if (sm_id < kNumRdmaRanks) { + int dst_rdma_rank = sm_id; + const auto num_tokens_sent = + atomic_finished_counter_per_rdma[dst_rdma_rank]; + + if (thread_id < kNumQPs) { + auto dst_ptr = reinterpret_cast( + rdma_recv_count + rdma_rank * kNumQPs + thread_id); + + bool is_local_copy = dst_rdma_rank == rdma_rank; + if (is_local_copy) { // local copy + st_na_release(rdma_recv_count + rdma_rank * kNumQPs + thread_id, + -num_tokens_sent - 1); + } else { + nvshmemi_ibgda_amo_nonfetch_add( + reinterpret_cast(dst_ptr), + -num_tokens_sent - 1, + dst_rdma_rank * NUM_MAX_NVL_PEERS + nvl_rank, + thread_id); + } + } + __syncthreads(); + // clean + if (thread_id == 0) { + atomic_counter_per_rdma[dst_rdma_rank] = 0; + atomic_finished_counter_per_rdma[dst_rdma_rank] = 0; + } + } + if (sm_id == num_sms - 1) { + for (int i = thread_id; i < kNumExperts; i += num_threads) { + atomic_counter_per_expert[i] = 0; + } + } + +LOW_LATENCY_DISPATCH_RECV: + if ((phases & LOW_LATENCY_RECV_PHASE) == 0) return; + + // TODO(ZKK): only wait one rank complete, is need to wait all rank complete + if (rank >= a_start_rank && rank < a_start_rank + a_num_ranks) { + int e_num_rdma_rank = e_num_ranks / NUM_MAX_NVL_PEERS; + int e_start_rdma_rank = e_start_rank / NUM_MAX_NVL_PEERS; + + // ========== + const int sms_per_rdma = num_sms / kNumRdmaRanks; + const int src_rdma_rank = sm_id / sms_per_rdma; + if (src_rdma_rank < kNumRdmaRanks) { + const int sub_rdma_rank = sm_id % sms_per_rdma; + if (thread_id < kNumQPs) { + if (thread_id == 0) { + sub_rdma_rank == 0 ? 
packed_rdma_recv_count[src_rdma_rank] = -1 : 0; + } + } + } + + // ======== + if (thread_id < kNumExperts && sm_id == 0) { + const auto src_rank = thread_id / kNumLocalExperts; + const auto local_expert_idx = thread_id % kNumLocalExperts; + const auto recv_range = + packed_recv_layout_range + local_expert_idx * kNumRanks; + recv_range[src_rank] = pack2(0, 0); + } + + if (sm_id < e_num_rdma_rank && thread_id < NUM_MAX_NVL_PEERS) { + int src_rdma_rank = sm_id + e_start_rdma_rank; + auto lsl_flag_before = ld_acquire_sys_global( + rdma_recv_complete + src_rdma_rank * NUM_MAX_NVL_PEERS + thread_id); + if (M2N_LL_DEBUG) { + if (thread_id == 0) { + printf( + "[kernel][dispatch][wait] src_rdma_rank: %d, offset: %d, " + "flag_before: %d\n", + src_rdma_rank * NUM_MAX_NVL_PEERS + thread_id, + src_rdma_rank * NUM_MAX_NVL_PEERS + thread_id, + lsl_flag_before); + } + } + + auto start_time = clock64(); + auto wait_recv_cost = clock64(); + while ((ld_acquire_sys_global(rdma_recv_complete + + src_rdma_rank * NUM_MAX_NVL_PEERS + + thread_id)) == 0) { + // debug info of dispatch wait + if (M2N_LL_HANG_DEBUG) { + wait_recv_cost = clock64() - start_time; + if (wait_recv_cost > M2N_NUM_HANG_CYCLES) { + if (thread_id == 0) { + printf( + "[kernel][dispatch][wait] wait than clock cycles: %ld, " + "flags: ", + wait_recv_cost); + for (int i = 0; i < a_num_ranks + e_num_ranks; i++) { + auto lsl_flag_debug = ld_acquire_sys_global( + rdma_recv_complete + src_rdma_rank * NUM_MAX_NVL_PEERS + i); + printf("%d, ", lsl_flag_debug); + } + printf("\n"); + start_time = clock64(); + } + // break; + } + } + } + auto lsl_flag = ld_acquire_sys_global( + rdma_recv_complete + src_rdma_rank * NUM_MAX_NVL_PEERS + thread_id); + + rdma_recv_complete[src_rdma_rank * NUM_MAX_NVL_PEERS + thread_id] = 0; + if (M2N_LL_DEBUG) { + if (thread_id == 0) { + printf( + "[kernel][dispatch][wait][complete] src_rdma_rank: %d, flag: " + "%d\n", + src_rdma_rank * NUM_MAX_NVL_PEERS + thread_id, + lsl_flag); + } + } + } + return; + } + + // below code are only executed by MoE machine! + + /* RDMA Receiver and NVL Sender */ + // we should guarantee data in rdma_recv_x are valid in MoE machine, by while + // checking rdma_recv_count! and then do NVL send! rdma_recv_x's shape is + // [kNumRdmaRanks, num_max_dispatch_tokens_per_rank] in unit of + // num_bytes_per_msg! rdma_recv_count's shape is [kNumRdmaRanks, kNumQPs] + + { + const int sms_per_rdma = num_sms / kNumRdmaRanks; + const int src_rdma_rank = sm_id / sms_per_rdma; + + // atomic_recv_tokens_per_rdma_expert's shape is + // [kNumRdmaRanks,kNumRdmaExperts] Now, + // atomic_recv_tokens_per_rdma_expert's shape is [kNumRdmaExperts]! 
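+    // A hedged sketch of the counter layout assumed here: the flat array is
+    // row-major [kNumRdmaRanks, kNumRdmaExperts], and the offset applied
+    // below selects the row for this SM group's src_rdma_rank
+    // (`rdma_expert_counter` is a hypothetical helper, not part of this
+    // patch):
+    //
+    //   __device__ int* rdma_expert_counter(int* base, int src_rdma_rank,
+    //                                       int rdma_expert_idx) {
+    //     return base + src_rdma_rank * kNumRdmaExperts + rdma_expert_idx;
+    //   }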
+ atomic_recv_tokens_per_rdma_expert = + atomic_recv_tokens_per_rdma_expert + src_rdma_rank * kNumRdmaExperts; + + if (src_rdma_rank < kNumRdmaRanks) { + const int sub_sm_id = sm_id % sms_per_rdma; + const int src_rank = src_rdma_rank * NUM_MAX_NVL_PEERS + nvl_rank; + + const int rmda_offset = + src_rdma_rank * num_max_dispatch_tokens_per_rank * num_bytes_per_msg; + const auto rdma_recv_x_uint8 = + reinterpret_cast(rdma_recv_x) + rmda_offset; + const auto packed_rdma_recv_x_uint8 = + reinterpret_cast(packed_rdma_recv_x) + rmda_offset; + + __shared__ int shared_num_recv_tokens[1]; + int num_recv_tokens_per_rdma = -1; + if (thread_id < kNumQPs) { + // only read flag of attn machine, if one machine is fast and one + // machine is slow, this will have hang in the last micro batch + if (src_rdma_rank >= a_start_rdma_rank && + src_rdma_rank < a_start_rdma_rank + a_num_rdma_ranks) { + auto start_time = clock64(); + auto wait_recv_cost = clock64(); + while ((num_recv_tokens_per_rdma = ld_acquire_sys_global( + rdma_recv_count + src_rdma_rank * kNumQPs + thread_id)) == + 0) { + if (M2N_LL_HANG_DEBUG) { + if (thread_id == 0) { + wait_recv_cost = clock64() - start_time; + if (wait_recv_cost > M2N_NUM_HANG_CYCLES) { + printf( + "[kernel][dispatch][rdma_recv_count] wait than clock " + "cycles: %ld\n", + wait_recv_cost); + start_time = clock64(); + } + } + } + } + } + + if (thread_id == 0) { + sub_sm_id == 0 + ? packed_rdma_recv_count[src_rdma_rank] = num_recv_tokens_per_rdma + : 0; + shared_num_recv_tokens[0] = -num_recv_tokens_per_rdma - 1; + } + } + __syncthreads(); + num_recv_tokens_per_rdma = shared_num_recv_tokens[0]; + + // data is valid, begin to send these tokens through nvlink! + // remember these tokens are from src_rdma_rank! + for (int rdma_recv_token_idx = sub_sm_id; + rdma_recv_token_idx < num_recv_tokens_per_rdma; + rdma_recv_token_idx += sms_per_rdma) { + const int token_offset = rdma_recv_token_idx * num_bytes_per_msg; + const auto rdma_recv_x_uint8_now = rdma_recv_x_uint8 + token_offset; + const auto packed_rdma_recv_x_uint8_now = + packed_rdma_recv_x_uint8 + token_offset; + + const auto src_data = reinterpret_cast(rdma_recv_x_uint8_now); + const auto rdma_recv_x_scales = reinterpret_cast( + reinterpret_cast(src_data) + sizeof(int4) + hidden_bytes); + const auto rdma_recv_nvl_rank_meta = reinterpret_cast( + rdma_recv_x_scales + (kUseFP8 ? kNumScales : 0)); + + // here must be rdma_rank! + const int dst_nvl_experts = + *(rdma_recv_nvl_rank_meta + rdma_rank * (kTopk * 3 + 1)); + const auto rdma_recv_nvl_rank_meta_now = + rdma_recv_nvl_rank_meta + rdma_rank * (kTopk * 3 + 1) + 1; + + // Used in combine + if (warp_id == num_warps - 1) { + UNROLLED_WARP_COPY( + UNROLL_FACTOR, + lane_id, + num_int4_per_msg, + reinterpret_cast(packed_rdma_recv_x_uint8_now), + reinterpret_cast(rdma_recv_x_uint8_now), + ld_nc_global, + st_na_global); + __syncwarp(); + } + + // nvl sender + // we need send dst_nvl_experts times for this rdma_recv_token_idx token + // using one sm! 
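+      // For orientation, a hedged struct view of the per-destination
+      // metadata record read above: each of the kNumRdmaRanks records in a
+      // dispatch message is (kTopk * 3 + 1) ints (illustrative only; the
+      // kernel reads the raw ints directly):
+      //
+      //   struct NvlRankMeta {
+      //     int num_valid;        // how many of the kTopk entries are used
+      //     struct Entry {
+      //       int rdma_expert;    // expert id within the destination machine
+      //       int dst_index;      // slot reserved by the sender's atomicAdd
+      //       int weight_bits;    // topk weight, float bit-cast to int
+      //     } entries[kTopk];
+      //   };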
+ for (int loop_nvl_expert_i = warp_id; + loop_nvl_expert_i < dst_nvl_experts; + loop_nvl_expert_i += num_warps) { + const int rdma_local_expert_idx = + rdma_recv_nvl_rank_meta_now[loop_nvl_expert_i * 3]; + const int dst_nvl_rank = rdma_local_expert_idx / kNumLocalExperts; + const int dst_nvl_local_expert = + rdma_local_expert_idx % kNumLocalExperts; + + const int rdma_local_expert_cumsum_index = + rdma_recv_nvl_rank_meta_now[loop_nvl_expert_i * 3 + 1]; + + // write to nvl_recv_x[dst_nvl_rank] + // whose‘s shape is [kNumLocalExperts, kNumRanks, + // num_max_dispatch_tokens_per_rank] in unit of + // num_int4_per_msg_rdma_revecier_and_nvl_sender! kNumRanks means for + // each expert we need to know which rank this data is from! + const auto dst_data = + reinterpret_cast(nvl_recv_x[dst_nvl_rank]) + + ((dst_nvl_local_expert * kNumRanks + src_rank) * + num_max_dispatch_tokens_per_rank + + rdma_local_expert_cumsum_index) * + num_int4_per_msg_rdma_revecier_and_nvl_sender; + + if (lane_id == 0) { + st_na_global(reinterpret_cast(dst_data), + rdma_local_expert_cumsum_index); + } + + UNROLLED_WARP_COPY(UNROLL_FACTOR, + lane_id, + num_int4_per_msg_rdma_to_nvl, + dst_data + 1, + src_data + 1, + ld_nc_global, + st_na_global); + __syncwarp(); + // we need record how many tokens are sent to different experts in + // this machine! + lane_id == 0 + ? (atomic_add_release_global( + atomic_recv_tokens_per_rdma_expert + rdma_local_expert_idx, + 1)) + : 0; + } + } + __syncthreads(); + thread_id == 0 ? (atomic_add_release_global( + atomic_nvl_sender_multi_sms + src_rdma_rank, 1)) + : 0; + if (sub_sm_id == 0 && thread_id == 0) { + auto start_time = clock64(); + auto wait_recv_cost = clock64(); + while (ld_acquire_global(atomic_nvl_sender_multi_sms + src_rdma_rank) != + sms_per_rdma) { + if (M2N_LL_HANG_DEBUG) { + if (thread_id == 0) { + wait_recv_cost = clock64() - start_time; + if (wait_recv_cost > M2N_NUM_HANG_CYCLES) { + printf( + "[kernel][dispatch][atomic_nvl_sender_multi_sms] wait than " + "clock cycles: %ld\n", + wait_recv_cost); + start_time = clock64(); + } + } + } + } + atomic_nvl_sender_multi_sms[src_rdma_rank] = 0; + } + __syncthreads(); + if (sub_sm_id == 0) { + // need tell nvl receive how many tokens we have send from src_rdma_rank + // machine! + for (int dst_rdma_local_expert_idx = thread_id; + dst_rdma_local_expert_idx < NUM_MAX_NVL_PEERS * kNumLocalExperts; + dst_rdma_local_expert_idx += num_threads) { + const int dst_nvl_rank = dst_rdma_local_expert_idx / kNumLocalExperts; + const int dst_nvl_local_expert = + dst_rdma_local_expert_idx % kNumLocalExperts; + + st_release_sys_global( + reinterpret_cast( + reinterpret_cast(nvl_recv_x[dst_nvl_rank]) + + NVL_BUFFER_X_BYTES) + + dst_nvl_local_expert * kNumRanks + src_rank, + -ld_acquire_global(atomic_recv_tokens_per_rdma_expert + + dst_rdma_local_expert_idx) - + 1); + // reset + *(atomic_recv_tokens_per_rdma_expert + dst_rdma_local_expert_idx) = 0; + } + for (int reset_i = thread_id; reset_i < kNumQPs; + reset_i += num_threads) { + rdma_recv_count[src_rdma_rank * kNumQPs + reset_i] = 0; + } + } + } + } + + /* NVL Receiver */ + if (responsible_expert_idx < kNumExperts) { + const auto src_rank = responsible_expert_idx / kNumLocalExperts; + const auto local_expert_idx = responsible_expert_idx % kNumLocalExperts; + // local_expert_idx receiveom src_rank! 
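+    // Readiness flags below use a negative encoding so that a zero-filled
+    // buffer means "not ready": a sender publishes n tokens by storing
+    // -(n + 1), and this receiver decodes n = -flag - 1. A hedged sketch of
+    // the round-trip (hypothetical helpers, not part of this patch):
+    //
+    //   __device__ __forceinline__ int encode_count(int n) { return -n - 1; }
+    //   __device__ __forceinline__ int decode_count(int f) { return -f - 1; }
+    //   // sender:   st_release_sys_global(flag_ptr, encode_count(n));
+    //   // receiver: int f;
+    //   //           while ((f = ld_acquire_sys_global(flag_ptr)) == 0) {}
+    //   //           const int n = decode_count(f);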
+ const int recv_offset_this_warpgroup = + local_expert_idx * kNumRanks + src_rank; + + const auto nvl_recv_x_uint8 = + reinterpret_cast(nvl_recv_x[nvl_rank]) + + recv_offset_this_warpgroup * num_max_dispatch_tokens_per_rank * + num_bytes_per_msg_rdma_revecier_and_nvl_sender; + const auto recv_x_int4 = reinterpret_cast(packed_recv_x) + + local_expert_idx * kNumRanks * + num_max_dispatch_tokens_per_rank * hidden_int4; + const auto recv_x_scales = + packed_recv_x_scales + local_expert_idx * kNumRanks * + num_max_dispatch_tokens_per_rank * + kNumScales; + const auto recv_src_info = + packed_recv_src_info + + local_expert_idx * kNumRanks * num_max_dispatch_tokens_per_rank; + const auto recv_range = + packed_recv_layout_range + local_expert_idx * kNumRanks; + + // Shared between sub-warps in warp groups + __shared__ int shared_num_recv_tokens[kNumWarpGroups], + shared_recv_token_begin_idx[kNumWarpGroups]; + + // Wait tokens to arrive + int num_recv_tokens, recv_token_begin_idx; + EP_STATIC_ASSERT(kNumWarpsPerGroup > 1, + "Requires more than one warp per group"); + if (sub_warp_id == 1 && lane_id == 0) { + auto start_time = clock64(); + auto wait_recv_cost = clock64(); + while ((num_recv_tokens = ld_acquire_sys_global( + reinterpret_cast( + reinterpret_cast(nvl_recv_x[nvl_rank]) + + NVL_BUFFER_X_BYTES) + + recv_offset_this_warpgroup)) == 0) { + if (M2N_LL_HANG_DEBUG) { + if (thread_id == 0) { + wait_recv_cost = clock64() - start_time; + if (wait_recv_cost > M2N_NUM_HANG_CYCLES) { + printf( + "[kernel][dispatch][nvl_recv_x] wait than clock cycles: " + "%ld\n", + wait_recv_cost); + start_time = clock64(); + } + } + } + } + num_recv_tokens = -num_recv_tokens - 1; + recv_token_begin_idx = + atomicAdd(packed_recv_count + local_expert_idx, num_recv_tokens); + shared_num_recv_tokens[warp_group_id] = num_recv_tokens; + shared_recv_token_begin_idx[warp_group_id] = recv_token_begin_idx; + recv_range[src_rank] = + pack2(num_recv_tokens, recv_token_begin_idx); + // reset nvl_recv_token_num + *(reinterpret_cast( + reinterpret_cast(nvl_recv_x[nvl_rank]) + + NVL_BUFFER_X_BYTES) + + recv_offset_this_warpgroup) = 0; + } + asm volatile("bar.sync %0, %1;" ::"r"(warp_group_id + 2), + "r"(kNumWarpsPerGroup * 32)); + num_recv_tokens = shared_num_recv_tokens[warp_group_id]; + recv_token_begin_idx = shared_recv_token_begin_idx[warp_group_id]; + + // Copy tokens + EP_DEVICE_ASSERT(kNumScales <= 64); + for (int i = sub_warp_id; i < num_recv_tokens; i += kNumWarpsPerGroup) { + // Copy source info + const auto src_src_idx = reinterpret_cast( + nvl_recv_x_uint8 + + i * num_bytes_per_msg_rdma_revecier_and_nvl_sender); + if (lane_id == 0) + recv_src_info[recv_token_begin_idx + i] = ld_nc_global(src_src_idx); + __syncwarp(); + + // Copy data + const auto src_data = reinterpret_cast( + reinterpret_cast(src_src_idx) + sizeof(int4)); + const auto dst_data = + recv_x_int4 + (recv_token_begin_idx + i) * hidden_int4; + UNROLLED_WARP_COPY(UNROLL_FACTOR, + lane_id, + hidden_int4, + dst_data, + src_data, + ld_nc_global, + st_na_global); + + // Copy scales + if (kUseFP8) { + const auto src_scales = reinterpret_cast( + reinterpret_cast(src_data) + hidden_bytes); + const auto dst_scales = + reinterpret_cast(recv_x_scales + recv_token_begin_idx + i); + const auto scale_stride = kNumRanks * num_max_dispatch_tokens_per_rank; + auto scale_0 = + lane_id < kNumScales ? ld_nc_global(src_scales + lane_id) : 0; + auto scale_1 = (lane_id + 32) < kNumScales + ? ld_nc_global(src_scales + lane_id + 32) + : 0; + lane_id < kNumScales ? 
dst_scales[lane_id * scale_stride] = scale_0 : 0.0f;
+        (lane_id + 32) < kNumScales
+            ? dst_scales[(lane_id + 32) * scale_stride] = scale_1
+            : 0.0f;
+      }
+    }
+  }
+
+  // Why is this grid-wide sync needed here?
+  // Keep it anyway, to avoid errors!
+  cg::this_grid().sync();
+
+  // TODO(ZKK): Stuff.
+  if (rank >= e_start_rank && rank < e_start_rank + e_num_ranks) {
+    if (sm_id < a_num_rdma_ranks && thread_id < NUM_MAX_NVL_PEERS) {
+      int dst_rdma_rank = sm_id + a_start_rdma_rank;
+      auto dst_ptr = reinterpret_cast<uint64_t>(
+          rdma_recv_complete + rdma_rank * NUM_MAX_NVL_PEERS + nvl_rank);
+
+      nvshmemi_ibgda_amo_nonfetch_add(
+          reinterpret_cast<int*>(dst_ptr),
+          1,
+          dst_rdma_rank * NUM_MAX_NVL_PEERS + thread_id,
+          thread_id);
+      if (M2N_LL_DEBUG) {
+        if (thread_id == 0) {
+          printf("[kernel][dispatch][complete] dst_rank: %d, offset: %d\n",
+                 dst_rdma_rank * NUM_MAX_NVL_PEERS + thread_id,
+                 rdma_rank * NUM_MAX_NVL_PEERS + nvl_rank);
+        }
+      }
+    }
+  }
+}
+
+void dispatch(void* packed_recv_x,
+              float* packed_recv_x_scales,
+              void* packed_rdma_recv_x,
+              int* packed_recv_src_info,
+              int64_t* packed_recv_layout_range,
+              int* packed_recv_count,
+              int* packed_rdma_recv_count,
+              bool* rdma_send_flags,
+              void* rdma_recv_x,
+              int* rdma_recv_count,
+              int* rdma_recv_complete,
+              void* rdma_x,
+              void** nvl_recv_x,
+              const void* x,
+              const int64_t* topk_idx,
+              const float* topk_weights,
+              int* next_clean,
+              int num_next_clean_int,
+              int num_tokens,
+              int hidden,
+              int num_max_dispatch_tokens_per_rank,
+              int num_topk,
+              int num_experts,
+              int rank,
+              int num_ranks,
+              int a_start_rank,
+              int a_num_ranks,
+              int e_start_rank,
+              int e_num_ranks,
+              bool use_fp8,
+              void* workspace,
+              cudaStream_t stream,
+              int phases) {
+  constexpr int kNumMaxTopK = 8;
+  constexpr int kNumQPs = 32;
+  constexpr int NUM_WARPS = 32;
+
+  const int dev_id = 0;
+  int sm_count;
+  cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev_id);
+  // Note: the queried SM count is deliberately overridden with a fixed 24.
+  sm_count = 24;
+  int num_warp_groups = cell_div(num_experts, sm_count);
+  num_warp_groups =
+      (num_warp_groups % 2 == 1) ?
num_warp_groups + 1 : num_warp_groups; + const auto num_sms = max(sm_count, cell_div(num_experts, num_warp_groups)); + // const auto num_sms = 24; + EP_HOST_ASSERT(num_topk <= kNumMaxTopK); + const int num_rdma_ranks = num_ranks / NUM_MAX_NVL_PEERS; + const int num_rdma_experts = num_experts / num_rdma_ranks; + // Workspace checks + auto atomic_counter_per_expert = reinterpret_cast(workspace); + auto atomic_counter_per_rdma = atomic_counter_per_expert + num_experts; + auto atomic_finished_counter_per_rdma = + atomic_counter_per_rdma + num_rdma_ranks; + auto atomic_recv_tokens_per_rdma_expert = + atomic_finished_counter_per_rdma + num_rdma_ranks; + auto atomic_nvl_sender_multi_sms = + atomic_recv_tokens_per_rdma_expert + + num_rdma_ranks * num_rdma_experts; // num_rdma_ranks + auto atomic_counter_per_qp = + atomic_nvl_sender_multi_sms + num_rdma_ranks; // num_rdma_ranks * kNumQPs + EP_HOST_ASSERT((num_experts + num_rdma_ranks * 3 + num_rdma_experts + + num_rdma_ranks * kNumQPs) * + sizeof(int) <= + NUM_WORKSPACE_BYTES); + + DISPATCH_HIDDEN_SIZE( + hidden, + kHidden, + {DISPATCH_NUM_TOPK( + num_topk, + kTopk, + {DISPATCH_RDMA_RANKS( + num_rdma_ranks, + kNumRdmaRanks, + {DISPATCH_NUM_EXPERTS( + num_experts, + kNumExperts, + {DISPATCH_NUM_WARP_GROUPS(num_warp_groups, kNumWarpGroups, { + constexpr int kNumWarpsPerGroup = + NUM_WARPS / kNumWarpGroups; + assert(num_rdma_ranks <= + kNumWarpGroups * kNumWarpsPerGroup); + EP_STATIC_ASSERT( + kNumMaxTopK + 1 <= kNumWarpGroups * kNumWarpsPerGroup, + "Too many top-k selections"); + auto dispatch_func = + use_fp8 ? dispatch_kernel + : dispatch_kernel; + SETUP_LAUNCH_CONFIG(num_sms, + kNumWarpGroups * kNumWarpsPerGroup * 32, + stream); + LAUNCH_KERNEL(&cfg, + dispatch_func, + packed_recv_x, + packed_recv_x_scales, + packed_rdma_recv_x, + packed_recv_src_info, + packed_recv_layout_range, + packed_recv_count, + packed_rdma_recv_count, + rdma_send_flags, + rdma_recv_x, + rdma_recv_count, + rdma_recv_complete, + rdma_x, + nvl_recv_x, + x, + topk_idx, + topk_weights, + atomic_counter_per_expert, + atomic_counter_per_rdma, + atomic_finished_counter_per_rdma, + atomic_recv_tokens_per_rdma_expert, + atomic_nvl_sender_multi_sms, + atomic_counter_per_qp, + num_tokens, + num_max_dispatch_tokens_per_rank, + rank, + a_start_rank, + a_num_ranks, + e_start_rank, + e_num_ranks, + phases); + })})})})}); +} + +template +__global__ __launch_bounds__( + kNumWarpGroups* kNumWarpsPerGroup * 32, + 1) void combine_kernel(void* combined_x, + void* rdma_recv_x, + int* rdma_recv_flag, + void* rdma_send_x, + int* rdma_recv_complete, + void* dispatch_rdma_recv_x, + const int* dispatch_rdma_recv_count, + void** nvl_recv_buffer, + const void* x, + const int64_t* topk_idx, + const float* topk_weights, + const int* src_info, + const int64_t* layout_range, + const bool* rdma_send_flags, + int* atomic_clean_flag, + int* atomic_nvl_sender_multi_sms, + int num_combined_tokens, + int hidden, + int num_topk, + int num_max_dispatch_tokens_per_rank, + int num_experts, + int rank, + int num_ranks, + int a_start_rank, + int a_num_ranks, + int e_start_rank, + int e_num_ranks, + int phases) { + constexpr int UNROLL_FACTOR = kHidden / 1024; + constexpr int kNumRanks = kNumRdmaRanks * NUM_MAX_NVL_PEERS; + constexpr int kNumLocalExperts = kNumExperts / kNumRanks; + constexpr int kNumRdmaExperts = kNumLocalExperts * NUM_MAX_NVL_PEERS; + constexpr int kNumPerChannels = 128; + constexpr int kNumScales = kHidden / kNumPerChannels; + + const size_t num_bytes_per_msg_dispatch = + sizeof(int4) + + 
(kNumRdmaRanks * (kTopk * 3 + 1) * sizeof(int) + sizeof(int4) - 1) / + sizeof(int4) * sizeof(int4) + + (kDispatchUseFP8 ? (kHidden + kNumScales * sizeof(float)) + : (kHidden * sizeof(nv_bfloat16))); + const size_t num_bytes_per_msg_rdma_revecier_and_nvl_sender_dispatch = + sizeof(int4) + (kDispatchUseFP8 ? (kHidden + kNumScales * sizeof(float)) + : (kHidden * sizeof(nv_bfloat16))); + + const size_t dispatch_hidden_bytes = + kHidden * + (kDispatchUseFP8 ? sizeof(__nv_fp8_storage_t) : sizeof(nv_bfloat16)); + const size_t combine_hidden_bytes = kHidden * sizeof(nv_bfloat16); + const size_t combine_hidden_int4_num = combine_hidden_bytes / sizeof(int4); + + const auto sm_id = static_cast(blockIdx.x); + const auto num_sms = static_cast(gridDim.x); + const auto thread_id = static_cast(threadIdx.x); + const auto num_threads = static_cast(blockDim.x), + num_warps = num_threads / 32; + const auto warp_id = thread_id / 32, lane_id = get_lane_id(); + const auto num_local_experts = num_experts / num_ranks; + const auto warp_group_id = warp_id / kNumWarpsPerGroup; + const auto sub_warp_id = warp_id % kNumWarpsPerGroup; + const auto responsible_expert_idx = sm_id * kNumWarpGroups + warp_group_id; + int a_start_rdma_rank = a_start_rank / NUM_MAX_NVL_PEERS; + int a_num_rdma_ranks = a_num_ranks / NUM_MAX_NVL_PEERS; + int e_start_rdma_rank = e_start_rank / NUM_MAX_NVL_PEERS; + int e_num_rdma_ranks = e_num_ranks / NUM_MAX_NVL_PEERS; + + if (sm_id == 0 && thread_id == 0) { + EP_DEVICE_ASSERT(ibgda_get_state()->num_rc_per_pe >= kNumQPs); + } + + const auto rdma_rank = rank / NUM_MAX_NVL_PEERS, + nvl_rank = rank % NUM_MAX_NVL_PEERS; + + constexpr int kNumElemsPerInt4 = sizeof(int4) / sizeof(nv_bfloat16); + const size_t hidden_bf16_int4 = kHidden / kNumElemsPerInt4; + if (sm_id == 0 && thread_id == 0) { + EP_DEVICE_ASSERT(ibgda_get_state()->num_rc_per_pe >= kNumQPs); + // EP_DEVICE_ASSERT(num_threads >= hidden_bf16_int4); // TODO: lzy why + } + + constexpr size_t num_bytes_per_slot = kHidden * sizeof(nv_bfloat16); + const size_t DISPATCH_NVL_BUFFER_X_BYTES = + kNumLocalExperts * kNumRanks * num_max_dispatch_tokens_per_rank * + num_bytes_per_msg_rdma_revecier_and_nvl_sender_dispatch + + kNumExperts * sizeof(int); + const size_t COMBINE_NVL_BUFFER_X_BYTES = kNumRdmaExperts * kNumRdmaRanks * + num_max_dispatch_tokens_per_rank * + num_bytes_per_slot; + const size_t NVL_BUFFER_X_BYTES = + DISPATCH_NVL_BUFFER_X_BYTES + COMBINE_NVL_BUFFER_X_BYTES; + + if ((phases & LOW_LATENCY_SEND_PHASE) == 0) goto LOW_LATENCY_COMBINE_RECV; + + if (M2N_LL_ACC_DEBUG) { + if (sm_id == 0 && thread_id == 0) { + if (responsible_expert_idx < num_experts) { + const auto dst_rank = responsible_expert_idx / num_local_experts; + const auto dst_rdma_rank = dst_rank / NUM_MAX_NVL_PEERS; + const auto dst_nvl_rank = dst_rank % NUM_MAX_NVL_PEERS; + auto tmp = reinterpret_cast(nvl_recv_buffer[dst_nvl_rank] + + NVL_BUFFER_X_BYTES); + printf("nvl flag: "); + for (int i = 0; i < num_local_experts * num_ranks; i++) { + printf("%d, ", tmp[i]); + } + printf("\n"); + } + } + } + + /* NVL Sender */ + if (responsible_expert_idx < num_experts) { + // we will send local_expert_idx partial result to dst_rank! + // first + // we need issue them to dst_nvl_rank through nvlink! + // then rdma to dst_rdma_rank / dst_rank! 
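+    // Each warp group owns one (dst_rank, local_expert) pair:
+    // responsible_expert_idx = dst_rank * num_local_experts + local_expert_idx.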
+
+    const auto dst_rank = responsible_expert_idx / num_local_experts;
+    const auto dst_rdma_rank = dst_rank / NUM_MAX_NVL_PEERS;
+    const auto dst_nvl_rank = dst_rank % NUM_MAX_NVL_PEERS;
+    const auto local_expert_idx = responsible_expert_idx % num_local_experts;
+    // global_rdma_expert_idx means expert ids in the range of one machine!
+    const auto global_rdma_expert_idx =
+        nvl_rank * num_local_experts + local_expert_idx;
+    const auto local_x = reinterpret_cast<const int4*>(x) +
+                         local_expert_idx * num_ranks *
+                             num_max_dispatch_tokens_per_rank *
+                             hidden_bf16_int4;
+    const auto local_src_info =
+        src_info +
+        local_expert_idx * num_ranks *
+            num_max_dispatch_tokens_per_rank;  // [dst_rank_index_source,
+                                               // dst_rdma_index, topk_weight]
+    const auto layout =
+        __ldg(layout_range + local_expert_idx * num_ranks + dst_rank);
+
+    // Unpack layout
+    int offset, num_tokens_to_send;
+    unpack2(layout, num_tokens_to_send, offset);
+
+    // On attention ranks this must of course be zero!
+    // if (rank >= 0 && rank < 16) EP_DEVICE_ASSERT(num_tokens_to_send == 0);
+
+    for (int token_idx = sub_warp_id; token_idx < num_tokens_to_send;
+         token_idx += kNumWarpsPerGroup) {
+      const int idx_now = token_idx + offset;
+      const int* src_idxs = local_src_info + idx_now;
+      const int dst_rdma_index = src_idxs[0];
+      // nvl recv buffer
+      const auto dst_ptr = reinterpret_cast<int4*>(
+          reinterpret_cast<uint8_t*>(nvl_recv_buffer[dst_nvl_rank]) +
+          DISPATCH_NVL_BUFFER_X_BYTES +
+          ((global_rdma_expert_idx * kNumRdmaRanks + dst_rdma_rank) *
+               num_max_dispatch_tokens_per_rank +
+           dst_rdma_index) *
+              num_bytes_per_slot);
+      const auto x_int4 = local_x + idx_now * hidden_bf16_int4;
+      UNROLLED_WARP_COPY(7,
+                         lane_id,
+                         hidden_bf16_int4,
+                         dst_ptr,
+                         x_int4,
+                         ld_nc_global,
+                         st_na_global);
+      __syncwarp();
+    }
+    // Put nvl finished flag
+    EP_STATIC_ASSERT(kNumWarpsPerGroup > 1,
+                     "Requires more than one warp per group");
+    asm volatile("bar.sync %0, %1;" ::"r"(warp_group_id + 1),
+                 "r"(kNumWarpsPerGroup * 32));
+    if (sub_warp_id == 1 && lane_id == 0) {
+      auto dst_ptr = reinterpret_cast<int*>(reinterpret_cast<uint8_t*>(
+                         nvl_recv_buffer[dst_nvl_rank]) +
+                         NVL_BUFFER_X_BYTES) +
+                     global_rdma_expert_idx * kNumRdmaRanks + dst_rdma_rank;
+      st_release_sys_global(dst_ptr, 1);
+    }
+    __syncwarp();
+  }
+
+  // Wait for all NVL ranks to arrive
+  if (responsible_expert_idx < num_experts) {
+    EP_STATIC_ASSERT(kNumWarpsPerGroup > 1,
+                     "Invalid number of warps per group");
+    if (rdma_rank >= e_start_rdma_rank &&
+        rdma_rank < e_start_rdma_rank + e_num_rdma_ranks && sub_warp_id == 0 &&
+        lane_id == 0) {
+      // if (sub_warp_id == 0 && lane_id == 0) {
+      auto start_time = clock64();
+      auto wait_recv_cost = clock64();
+      while (ld_acquire_sys_global(
+                 reinterpret_cast<int*>(
+                     reinterpret_cast<uint8_t*>(nvl_recv_buffer[nvl_rank]) +
+                     NVL_BUFFER_X_BYTES) +
+                 responsible_expert_idx) == 0) {
+        if (M2N_LL_HANG_DEBUG) {
+          if (thread_id == 0) {
+            wait_recv_cost = clock64() - start_time;
+            if (wait_recv_cost > M2N_NUM_HANG_CYCLES) {
+              printf(
+                  "[kernel][combine][nvl_recv_buffer] waited more than "
+                  "%ld clock cycles\n",
+                  wait_recv_cost);
+              start_time = clock64();
+            }
+          }
+        }
+      }
+      // reset nvl_recv_buffer
+      *(reinterpret_cast<int*>(
+            reinterpret_cast<uint8_t*>(nvl_recv_buffer[nvl_rank]) +
+            NVL_BUFFER_X_BYTES) +
+        responsible_expert_idx) = 0;
+    }
+  }
+  cg::this_grid().sync();
+
+  /* NVL Receiver / NVL Reducer */
+  {
+    // receive data from NVLink and do the reduction,
+    // then issue the result!
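+    // SMs are split evenly across RDMA ranks: each contiguous slice of
+    // sms_per_rdma blocks reduces and forwards tokens for one RDMA rank.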
+ const int sms_per_rdma = num_sms / kNumRdmaRanks; + const int deal_rdma_rank = sm_id / sms_per_rdma; + if (deal_rdma_rank < kNumRdmaRanks) { + const int sub_deal_rdma_rank = sm_id % sms_per_rdma; + const int qp_id = sub_deal_rdma_rank % kNumQPs; + const int num_tokens_to_deal = + (-dispatch_rdma_recv_count[deal_rdma_rank] - 1); + const auto dispatch_rdma_recv_x_this_rdma_rank = + reinterpret_cast(dispatch_rdma_recv_x) + + deal_rdma_rank * num_max_dispatch_tokens_per_rank * + num_bytes_per_msg_dispatch; + auto rdma_send_x_this_rdma_rank = + reinterpret_cast(rdma_send_x) + + deal_rdma_rank * num_max_dispatch_tokens_per_rank * + combine_hidden_bytes; + // reduce + for (int rdma_recv_token_idx = sub_deal_rdma_rank; + rdma_recv_token_idx < num_tokens_to_deal; + rdma_recv_token_idx += sms_per_rdma) { + const auto dispatch_rdma_recv_x_now = + dispatch_rdma_recv_x_this_rdma_rank + + rdma_recv_token_idx * num_bytes_per_msg_dispatch; + const auto index_source = + reinterpret_cast(dispatch_rdma_recv_x_now)[0]; + const int* nvl_rank_meta = reinterpret_cast( + dispatch_rdma_recv_x_now + sizeof(int4) + dispatch_hidden_bytes + + (kDispatchUseFP8 ? kNumScales * sizeof(float) : 0)); + const int nvl_rank_nums = + *(nvl_rank_meta + rdma_rank * (kTopk * 3 + 1)); + const int* nvl_rank_meta_now = + nvl_rank_meta + rdma_rank * (kTopk * 3 + 1) + 1; + int4* dst_ptr = reinterpret_cast( + rdma_send_x_this_rdma_rank + index_source * combine_hidden_bytes); + float combined_values[kNumElemsPerInt4] = {0.0f}; + for (int g_id = thread_id; g_id < hidden_bf16_int4; + g_id += num_threads) { + for (int nvl_rank_idx = 0; nvl_rank_idx < nvl_rank_nums; + nvl_rank_idx += 1) { + const int dst_rdma_expert_idx = nvl_rank_meta_now[nvl_rank_idx * 3]; + const int dst_cum_index = nvl_rank_meta_now[nvl_rank_idx * 3 + 1]; + const float topk_weight = reinterpret_cast( + nvl_rank_meta_now)[nvl_rank_idx * 3 + 2]; + const int4* src_ptr = reinterpret_cast( + reinterpret_cast(nvl_recv_buffer[nvl_rank]) + + DISPATCH_NVL_BUFFER_X_BYTES + + ((dst_rdma_expert_idx * kNumRdmaRanks + deal_rdma_rank) * + num_max_dispatch_tokens_per_rank + + dst_cum_index) * + num_bytes_per_slot); + auto x_vec = ld_nc_global(src_ptr + g_id); + const auto x_bf16 = reinterpret_cast(&x_vec); +#pragma unroll + for (int j = 0; j < kNumElemsPerInt4; ++j) + combined_values[j] += static_cast(x_bf16[j]) * topk_weight; + } + int4& combined_int4 = *reinterpret_cast(combined_values); + auto combined_bf16 = reinterpret_cast(&combined_values); +#pragma unroll + for (int j = 0; j < kNumElemsPerInt4; ++j) + combined_bf16[j] = static_cast(combined_values[j]); + dst_ptr[g_id] = combined_int4; + } + __syncthreads(); + // issue copy to remote rdma per token + if (warp_id == 0) { + const auto src_ptr = reinterpret_cast( + rdma_send_x_this_rdma_rank + index_source * combine_hidden_bytes); + const auto dst_ptr = + reinterpret_cast(rdma_recv_x) + + (rdma_rank * num_max_dispatch_tokens_per_rank + index_source) * + combine_hidden_bytes; + if (rdma_rank == deal_rdma_rank) { + // local copy + const auto* src_int4_ptr = reinterpret_cast(src_ptr); + const auto* dst_int4_ptr = reinterpret_cast(dst_ptr); + UNROLLED_WARP_COPY(UNROLL_FACTOR, + lane_id, + combine_hidden_int4_num, + dst_int4_ptr, + src_int4_ptr, + ld_nc_global, + st_na_global); + } else { + if constexpr (kNumQPs > 1) { + nvshmemi_ibgda_put_nbi_warp( + dst_ptr, + src_ptr, + combine_hidden_bytes, + deal_rdma_rank * NUM_MAX_NVL_PEERS + nvl_rank, + qp_id, + lane_id, + 0); + } else { + nvshmemi_ibgda_put_nbi_warp( + dst_ptr, + src_ptr, + 
combine_hidden_bytes, + deal_rdma_rank * NUM_MAX_NVL_PEERS + nvl_rank, + qp_id, + lane_id, + rdma_recv_token_idx); + } + } + __syncwarp(); + } + } + thread_id == 0 ? (atomic_add_release_global( + atomic_nvl_sender_multi_sms + deal_rdma_rank, 1)) + : 0; + // all sms reduce done + if (sub_deal_rdma_rank == 0 && thread_id == 0) { + auto start_time = clock64(); + auto wait_recv_cost = clock64(); + while (ld_acquire_global(atomic_nvl_sender_multi_sms + + deal_rdma_rank) != sms_per_rdma) { + if (M2N_LL_HANG_DEBUG) { + if (thread_id == 0) { + wait_recv_cost = clock64() - start_time; + if (wait_recv_cost > M2N_NUM_HANG_CYCLES) { + printf( + "[kernel][combine][atomic_nvl_sender_multi_sms] wait than " + "clock cycles: %ld\n", + wait_recv_cost); + start_time = clock64(); + } + } + } + } + atomic_nvl_sender_multi_sms[deal_rdma_rank] = 0; + } + __syncthreads(); + // set flag + if (sub_deal_rdma_rank == 0 && thread_id < kNumQPs) { + // notify remote rdma + auto dst_rdma_flag = reinterpret_cast( + rdma_recv_flag + rdma_rank * kNumQPs + thread_id); + bool is_local_copy = deal_rdma_rank == rdma_rank; + if (is_local_copy) { + st_na_release(rdma_recv_flag + rdma_rank * kNumQPs + thread_id, 1); + } else { + nvshmemi_ibgda_amo_nonfetch_add( + reinterpret_cast(dst_rdma_flag), + 1, + deal_rdma_rank * NUM_MAX_NVL_PEERS + nvl_rank, + qp_id); + } + } + } + } + +LOW_LATENCY_COMBINE_RECV: + if ((phases & LOW_LATENCY_RECV_PHASE) == 0) return; + + // TODO(ZKK): stuff. + if (rank >= e_start_rank && rank < e_start_rank + e_num_ranks) { + if (sm_id < a_num_rdma_ranks && thread_id < NUM_MAX_NVL_PEERS) { + int src_rdma_rank = sm_id + a_start_rdma_rank; + auto lsl_flag_before = + ld_acquire_sys_global(rdma_recv_complete + num_ranks + + src_rdma_rank * NUM_MAX_NVL_PEERS + thread_id); + if (M2N_LL_DEBUG) { + if (thread_id == 0) { + printf( + "[kernel][combine][wait] src_rdma_rank: %d, offset: %d, " + "flag_before: %d\n", + src_rdma_rank * NUM_MAX_NVL_PEERS + thread_id, + num_ranks + src_rdma_rank * NUM_MAX_NVL_PEERS + thread_id, + lsl_flag_before); + } + } + auto start_time = clock64(); + auto wait_recv_cost = clock64(); + while ((ld_acquire_sys_global(rdma_recv_complete + num_ranks + + src_rdma_rank * NUM_MAX_NVL_PEERS + + thread_id)) == 0) { + if (M2N_LL_HANG_DEBUG) { + if (thread_id == 0) { + wait_recv_cost = clock64() - start_time; + if (wait_recv_cost > M2N_NUM_HANG_CYCLES) { + printf("[kernel][combine][wait] wait than clock cycles: %ld\n", + wait_recv_cost); + start_time = clock64(); + } + } + } + } + auto lsl_flag = + ld_acquire_sys_global(rdma_recv_complete + num_ranks + + src_rdma_rank * NUM_MAX_NVL_PEERS + thread_id); + + rdma_recv_complete[num_ranks + src_rdma_rank * NUM_MAX_NVL_PEERS + + thread_id] = 0; + if (M2N_LL_DEBUG) { + if (thread_id == 0) { + printf( + "[kernel][combine][wait][complete] src_rdma_rank: %d, flag: %d\n", + src_rdma_rank * NUM_MAX_NVL_PEERS + thread_id, + lsl_flag); + } + } + } + return; + } + + /* RDMA Receiver / RDMA Reducer */ + // Wait all rdma ranks to arrive + // only read flag of experts machine, if one machine is fast and one machine + // is slow, this will have hang in the last micro batch + + if (sm_id >= e_start_rdma_rank && + sm_id < e_start_rdma_rank + e_num_rdma_ranks && sm_id < kNumRdmaRanks) { + if (thread_id < kNumQPs) { + auto start_time = clock64(); + auto wait_recv_cost = clock64(); + while (ld_acquire_sys_global(rdma_recv_flag + sm_id * kNumQPs + + thread_id) == 0) { + if (M2N_LL_HANG_DEBUG) { + if (thread_id == 0) { + wait_recv_cost = clock64() - start_time; + if 
(wait_recv_cost > M2N_NUM_HANG_CYCLES) { + printf( + "[kernel][combine][rdma_recv_flag] wait than clock cycles: " + "%ld\n", + wait_recv_cost); + start_time = clock64(); + } + } + } + } + // reset + rdma_recv_flag[sm_id * kNumQPs + thread_id] = 0; + } + } + + cg::this_grid().sync(); + + for (int g_id = thread_id; g_id < hidden_bf16_int4; g_id += num_threads) { + for (int token_idx = sm_id; token_idx < num_combined_tokens; + token_idx += num_sms) { + float combined_values[kNumElemsPerInt4] = {0.0f}; + const bool* rdma_send_flags_now = + rdma_send_flags + token_idx * kNumRdmaRanks; + for (int rdma_rank_idx = 0; rdma_rank_idx < kNumRdmaRanks; + ++rdma_rank_idx) { + if (rdma_send_flags_now[rdma_rank_idx]) { + const int4* src_ptr = reinterpret_cast( + reinterpret_cast(rdma_recv_x) + + (rdma_rank_idx * num_max_dispatch_tokens_per_rank + token_idx) * + combine_hidden_bytes); + auto x_vec = ld_nc_global(src_ptr + g_id); + const auto x_bf16 = reinterpret_cast(&x_vec); +#pragma unroll + for (int j = 0; j < kNumElemsPerInt4; ++j) + combined_values[j] += static_cast(x_bf16[j]); + } + } + // Write results + int4& combined_int4 = *reinterpret_cast(combined_values); + auto combined_bf16 = reinterpret_cast(&combined_values); +#pragma unroll + for (int j = 0; j < kNumElemsPerInt4; ++j) + combined_bf16[j] = static_cast(combined_values[j]); + (reinterpret_cast(combined_x) + + token_idx * hidden_bf16_int4)[g_id] = combined_int4; + } + } + + // + cg::this_grid().sync(); + + // TODO(ZKK): stuff. + if (rank >= a_start_rank && rank < a_start_rank + a_num_ranks) { + // int e_num_rdma_ranks = e_num_ranks / NUM_MAX_NVL_PEERS; + // int e_start_rdma_rank = e_start_rank / NUM_MAX_NVL_PEERS; + // int a_start_rdma_rank = a_start_rank / NUM_MAX_NVL_PEERS; + if (sm_id < e_num_rdma_ranks && thread_id < NUM_MAX_NVL_PEERS) { + int dst_rdma_rank = sm_id + e_start_rdma_rank; + auto dst_ptr = + reinterpret_cast(rdma_recv_complete + num_ranks + + rdma_rank * NUM_MAX_NVL_PEERS + nvl_rank); + + nvshmemi_ibgda_amo_nonfetch_add( + reinterpret_cast(dst_ptr), + 1, + dst_rdma_rank * NUM_MAX_NVL_PEERS + thread_id, + thread_id); + if (M2N_LL_DEBUG) { + if (thread_id == 0) { + printf("[kernel][combine][complete] dst_rank: %d, offset: %d\n", + dst_rdma_rank * NUM_MAX_NVL_PEERS + thread_id, + num_ranks + rdma_rank * NUM_MAX_NVL_PEERS + nvl_rank); + } + } + } + } +} + +void combine(void* combined_x, + void* rdma_recv_x, + int* rdma_recv_flag, + void* rdma_send_x, + int* rdma_recv_complete, + void* dispatch_rdma_recv_x, + const int* dispatch_rdma_recv_count, + void** nvl_buffer, + const void* x, // num_local_experts * num_ranks * kHidden + const int64_t* topk_idx, + const float* topk_weights, + const int* src_info, + const int64_t* layout_range, + const bool* rdma_send_flags, + int* next_clean, + int num_next_clean_int, + int num_combined_tokens, + int hidden, + int num_max_dispatch_tokens_per_rank, + int num_topk, + int num_experts, + int rank, + int num_ranks, + int a_start_rank, + int a_num_ranks, + int e_start_rank, + int e_num_ranks, + void* workspace, + cudaStream_t stream, + int phases, + bool dispatch_use_fp8) { + constexpr int kNumMaxTopk = 8; + constexpr int kNumQPs = 4; + constexpr int NUM_WARPS = 32; + + const int dev_id = 0; + int sm_count; + cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev_id); + sm_count = 24; + int num_warp_groups = cell_div(num_experts, sm_count); + num_warp_groups = + (num_warp_groups % 2 == 1) ? 
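+      // keep the warp-group count even, mirroring dispatch() above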
num_warp_groups + 1 : num_warp_groups; + const auto num_sms = max(sm_count, cell_div(num_experts, num_warp_groups)); + // const auto num_sms = 24; + const int num_rdma_ranks = num_ranks / NUM_MAX_NVL_PEERS; + + // Check workspace + auto atomic_clean_flag = reinterpret_cast(workspace); + auto atomic_nvl_sender_multi_sms = atomic_clean_flag + 1; + EP_HOST_ASSERT((1 + num_rdma_ranks) * sizeof(int) <= NUM_WORKSPACE_BYTES); + EP_HOST_ASSERT(num_topk <= kNumMaxTopk); + + DISPATCH_HIDDEN_SIZE( + hidden, + kHidden, + {DISPATCH_NUM_TOPK( + num_topk, + kTopk, + {DISPATCH_RDMA_RANKS( + num_rdma_ranks, + kNumRdmaRanks, + {DISPATCH_NUM_EXPERTS( + num_experts, + kNumExperts, + {DISPATCH_NUM_WARP_GROUPS(num_warp_groups, kNumWarpGroups, { + constexpr int kNumWarpsPerGroup = + NUM_WARPS / kNumWarpGroups; + auto combine_func = dispatch_use_fp8 + ? combine_kernel + : combine_kernel; + SETUP_LAUNCH_CONFIG(num_sms, + kNumWarpGroups * kNumWarpsPerGroup * 32, + stream); + LAUNCH_KERNEL(&cfg, + combine_func, + combined_x, + rdma_recv_x, + rdma_recv_flag, + rdma_send_x, + rdma_recv_complete, + dispatch_rdma_recv_x, + dispatch_rdma_recv_count, + nvl_buffer, + x, + topk_idx, + topk_weights, + src_info, + layout_range, + rdma_send_flags, + atomic_clean_flag, + atomic_nvl_sender_multi_sms, + num_combined_tokens, + hidden, + num_topk, + num_max_dispatch_tokens_per_rank, + num_experts, + rank, + num_ranks, + a_start_rank, + a_num_ranks, + e_start_rank, + e_num_ranks, + phases); + })})})})}) +} + +} // namespace m2n_ll_two_stage + +} // namespace deep_ep diff --git a/paddle/fluid/pybind/deep_ep_api.cc b/paddle/fluid/pybind/deep_ep_api.cc index b35dec6d22304..b162fb1566b1e 100644 --- a/paddle/fluid/pybind/deep_ep_api.cc +++ b/paddle/fluid/pybind/deep_ep_api.cc @@ -106,7 +106,11 @@ void BindDeepEPApi(pybind11::module *m) { .def("low_latency_dispatch_two_stage", &deep_ep::Buffer::low_latency_dispatch_two_stage_api) .def("low_latency_combine_two_stage", - &deep_ep::Buffer::low_latency_combine_two_stage_api); + &deep_ep::Buffer::low_latency_combine_two_stage_api) + .def("m2n_low_latency_dispatch_two_stage", + &deep_ep::Buffer::m2n_low_latency_dispatch_two_stage_api) + .def("m2n_low_latency_combine_two_stage", + &deep_ep::Buffer::m2n_low_latency_combine_two_stage_api); #endif } diff --git a/python/paddle/distributed/communication/deep_ep/__init__.py b/python/paddle/distributed/communication/deep_ep/__init__.py index 7576af9e00027..711a855c131c1 100644 --- a/python/paddle/distributed/communication/deep_ep/__init__.py +++ b/python/paddle/distributed/communication/deep_ep/__init__.py @@ -12,7 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from .buffer import Buffer + +from .buffer import Buffer, M2NBuffer from .utils import ( EventOverlap, get_event_from_calc_stream, @@ -22,6 +23,7 @@ __all__ = [ "Buffer", + "M2NBuffer", "EventOverlap", "get_event_from_calc_stream", "get_event_from_comm_stream", diff --git a/python/paddle/distributed/communication/deep_ep/buffer.py b/python/paddle/distributed/communication/deep_ep/buffer.py index e7138a1a6c633..946e36197096a 100644 --- a/python/paddle/distributed/communication/deep_ep/buffer.py +++ b/python/paddle/distributed/communication/deep_ep/buffer.py @@ -39,6 +39,19 @@ from .utils import EventOverlap +class M2NWorker: + """ + M2NWork manage asynchronous events + """ + + def __init__(self, hook=None) -> None: + self.hook = hook + + def wait(self): + if self.hook is not None: + self.hook() + + class Buffer: """ The core expert-parallel (EP) communication buffers for Mixture of Experts (MoE) model, which supports: @@ -1217,3 +1230,524 @@ def low_latency_combine_two_stage( EventOverlap(event, tensors_to_record if async_finish else None), hook, ) + + def m2n_low_latency_dispatch_two_stage( + self, + x: paddle.Tensor, + topk_idx: paddle.Tensor, + topk_weights: paddle.Tensor, + pre_allocated_result_memory, + num_max_dispatch_tokens_per_rank: int, + num_experts: int, + a_start_rank: int, + a_num_ranks: int, + e_start_rank: int, + e_num_ranks: int, + use_fp8: bool = True, + async_finish: bool = False, + return_recv_hook: bool = False, + ) -> tuple[ + tuple[paddle.Tensor, paddle.Tensor], + paddle.Tensor, + tuple, + EventOverlap, + Callable, + ]: + """ + A low-latency-two-stage implementation for dispatching with IBGDA. + This kernel requires all the ranks (no matter intranode or internode) should be visible via RDMA + (specifically, IBGDA must be enabled). + + Arguments: + x: `paddle.Tensor` with `bfloat16`, shaped as `[num_tokens, hidden]`, only several hidden shapes are + supported. The number of tokens to be dispatched must be less than `num_max_dispatch_tokens_per_rank`. + topk_idx: `paddle.Tensor` with `int64`, shaped as `[num_tokens, num_topk]`, only several top-k shapes + are supported. `-1` indices (not selecting any expert) are supported. + topk_weights: `paddle.Tensor` with `float`, shaped as `[num_tokens, num_topk]`, only several top-k shapes + are supported. + num_max_dispatch_tokens_per_rank: the maximum number of tokens to dispatch, all the ranks must hold the same value. + num_experts: the number of all experts. + use_fp8: whether to enable FP8 casting, with this, the received data will be a tuple of FP8 tensor and scaling factors. + async_finish: the current stream will not wait for the communication kernels to be finished if set. + return_recv_hook: return a receiving hook if set. If set, the kernel will just do the RDMA request issues, + but **without actually receiving the data**. You must call the received hook to make sure the data's arrival. + If you not set this flag, the kernel will ensure the data's arrival. + + Returns: + recv_x: a tensor or tuple with received tokens for each expert. + With `use_fp8=True`: the first element is a `paddle.Tensor` shaped as + `[num_local_experts, num_max_dispatch_tokens_per_rank * num_ranks, hidden]` with `float8_e4m3fn`. + The second tensor is the corresponding scales for the first element with shape + `[num_local_experts, num_max_dispatch_tokens_per_rank * num_ranks, hidden // 128]` with `float`. + Notice that, the last-two-dimension of the scaling tensors are in column-major for TMA compatibility. 
+ With `use_fp8=False`, the result would be a tensor shaped as + `[num_local_experts, num_max_dispatch_tokens_per_rank * num_ranks, hidden]` with `bfloat16`. + Moreover, not all tokens are valid, only some of the `num_max_dispatch_tokens_per_rank * num_ranks` are, + as we do not synchronize CPU received count with GPU (also not incompatible with CUDA graph if synced). + recv_count: a tensor shaped `[num_local_experts]` with type `int`, indicating how many tokens each + expert receive. As mentioned before, not all tokens are valid in `recv_x`. + packed_rdma_recv_count: a tensor shaped `[num_rdma_ranks]` with type `int`, indicating how many tokens each + rdma_rank receive. + handle: the communication handle to be used in the `low_latency_combine` function. + event: the event after executing the kernel (valid only if `async_finish` is set). + hook: the receiving hook function (valid only if `return_recv_hook` is set). + """ + ( + packed_recv_x, + packed_recv_x_scales, + packed_recv_rdma_x, + packed_recv_count, + packed_rdma_recv_count, + packed_recv_src_info, + packed_recv_layout_range, + rdma_send_flags, + event, + hook, + ) = self.runtime.m2n_low_latency_dispatch_two_stage( + x, + topk_idx, + topk_weights, + pre_allocated_result_memory, + num_max_dispatch_tokens_per_rank, + num_experts, + a_start_rank, + a_num_ranks, + e_start_rank, + e_num_ranks, + use_fp8, + async_finish, + return_recv_hook, + ) + handle = ( + packed_recv_rdma_x, + packed_recv_src_info, + packed_recv_layout_range, + rdma_send_flags, + packed_rdma_recv_count, + num_max_dispatch_tokens_per_rank, + x.shape[1], + num_experts, + ) + tensors_to_record = ( + x, + topk_idx, + topk_weights, + packed_recv_x, + packed_recv_x_scales, + packed_recv_rdma_x, + packed_recv_count, + packed_rdma_recv_count, + packed_recv_src_info, + packed_recv_layout_range, + rdma_send_flags, + ) + return ( + (packed_recv_x, packed_recv_x_scales) if use_fp8 else packed_recv_x, + packed_recv_count, + rdma_send_flags, + handle, + EventOverlap(event, tensors_to_record if async_finish else None), + hook, + ) + + def m2n_low_latency_combine_two_stage( + self, + x: paddle.Tensor, + topk_idx: paddle.Tensor, + topk_weights: paddle.Tensor, + handle: tuple, + a_start_rank: int, + a_num_ranks: int, + e_start_rank: int, + e_num_ranks: int, + dispatch_use_fp8: bool = False, + async_finish: bool = False, + return_recv_hook: bool = False, + out: paddle.Tensor | None = None, + ) -> tuple[paddle.Tensor, EventOverlap, Callable]: + """ + A low-latency implementation for combining tokens (reduce **with weights**) with IBGDA. + This kernel requires all the ranks (no matter intranode or internode) should be visible via RDMA + (specifically, IBGDA must be enabled). + Even for ranks in the same node, NVLink are fully disabled for simplicity. + Warning: as there are only two buffers, and the returned tensors reuse the buffer, you can not hold more than 2 + low-latency kernels' result tensor at a single moment. + + Arguments: + x: `[num_local_experts, num_max_dispatch_tokens_per_rank * num_ranks, hidden]` with `bfloat16`, + the local calculated tokens to be sent to this original rank and reduced. + topk_idx: `[num_combined_tokens, num_topk]` with `int64`, the expert indices selected by the dispatched + tokens. `-1` indices (not selecting any expert) are supported. Note that, `num_combined_tokens` equals + to the number of dispatched tokens. + topk_weights: `[num_combined_tokens, num_topk]` with `float`, the expert weights selected by the dispatched + tokens. 
The received tokens will be reduced with the weights in this tensor. + handle: the communication handle given by the `dispatch` function. + dispatch_use_fp8: whether to enable FP8 casting in dispatch. + async_finish: the current stream will not wait for the communication kernels to be finished if set. + return_recv_hook: return a receiving hook if set. If set, the kernel will just do the RDMA request issues, + but **without actually receiving the data**. You must call the received hook to make sure the data's arrival. + If you not set this flag, the kernel will ensure the data's arrival. + out: the in-place output tensor, if set, the kernel will write the result to this tensor and return it directly. + + Returns: + combined_x: the reduced token tensor, with shape `[num_combined_tokens, hidden]` and type `bfloat16`. + event: the event after executing the kernel (valid only if `async_finish` is set). + hook: the receiving hook function (valid only if `return_recv_hook` is set). + """ + ( + packed_recv_rdma_x, + src_info, + layout_range, + rdma_send_flags, + packed_rdma_recv_count, + num_max_dispatch_tokens_per_rank, + hidden, + num_experts, + ) = handle + combined_x, event, hook = ( + self.runtime.m2n_low_latency_combine_two_stage( + x, + packed_recv_rdma_x, + topk_idx, + topk_weights, + src_info, + layout_range, + rdma_send_flags, + packed_rdma_recv_count, + num_max_dispatch_tokens_per_rank, + num_experts, + a_start_rank, + a_num_ranks, + e_start_rank, + e_num_ranks, + dispatch_use_fp8, + async_finish, + return_recv_hook, + out, + ) + ) + tensors_to_record = ( + x, + topk_idx, + topk_weights, + src_info, + layout_range, + combined_x, + ) + return ( + combined_x, + EventOverlap(event, tensors_to_record if async_finish else None), + hook, + ) + + def m2n_get_pre_allocated_memory( + self, + num_tokens, + num_topk, + hidden, + num_max_dispatch_tokens_per_rank, + use_fp8, + ): + tmp = self.runtime.m2n_get_pre_allocated_memory( + num_tokens, + num_topk, + hidden, + num_max_dispatch_tokens_per_rank, + use_fp8, + ) + return tmp + + +class M2NBuffer: + def __init__( + self, + group: Group, + a_start_rank: int, + a_num_ranks: int, + e_start_rank: int, + e_num_ranks: int, + num_nvl_bytes: int = 0, + num_rdma_bytes: int = 0, + low_latency_mode: bool = False, + num_qps_per_rank: int = 12, + ) -> None: + self.a_start_rank = a_start_rank + self.a_num_ranks = a_num_ranks + self.e_start_rank = e_start_rank + self.e_num_ranks = e_num_ranks + self.all2all_buffer = Buffer( + group, + num_nvl_bytes=num_nvl_bytes, + num_rdma_bytes=num_rdma_bytes, + low_latency_mode=low_latency_mode, + num_qps_per_rank=num_qps_per_rank, + ) + + @staticmethod + def get_low_latency_rdma_size_hint_two_stage( + num_max_dispatch_tokens_per_rank: int, + hidden: int, + num_ranks: int, + a_num_ranks: int, + e_num_ranks: int, + num_experts: int, + num_topk: int, + ) -> int: + assert num_ranks == a_num_ranks + e_num_ranks + assert num_experts % e_num_ranks == 0 + m2n_num_experts = (num_experts // e_num_ranks) * ( + a_num_ranks + e_num_ranks + ) + return Buffer.get_low_latency_rdma_size_hint_two_stage( + num_max_dispatch_tokens_per_rank, + hidden, + num_ranks, + m2n_num_experts, + num_topk, + ) + + def get_low_latency_nvl_size_hint_two_stage( + num_max_dispatch_tokens_per_rank: int, + hidden: int, + num_ranks: int, + a_num_ranks: int, + e_num_ranks: int, + num_experts: int, + num_topk: int, + use_fp8: bool, + ) -> int: + assert num_ranks == a_num_ranks + e_num_ranks + assert num_experts % e_num_ranks == 0 + m2n_num_experts = (num_experts // 
e_num_ranks) * ( + a_num_ranks + e_num_ranks + ) + return Buffer.get_low_latency_nvl_size_hint_two_stage( + num_max_dispatch_tokens_per_rank, + hidden, + num_ranks, + m2n_num_experts, + num_topk, + use_fp8, + ) + + def m2n_get_pre_allocated_memory( + self, + num_tokens, + num_topk, + hidden, + num_max_dispatch_tokens_per_rank, + use_fp8, + ): + tmp = self.all2all_buffer.m2n_get_pre_allocated_memory( + num_tokens, + num_topk, + hidden, + num_max_dispatch_tokens_per_rank, + use_fp8, + ) + return tmp + + def a2e_isend_two_stage_v3( + self, + x: paddle.Tensor, + topk_idx: paddle.Tensor, + topk_weights: paddle.Tensor, + pre_allocated_result_memory, + num_max_dispatch_tokens_per_rank: int, + num_experts: int, + use_fp8: bool = True, + ) -> tuple[ + tuple[paddle.Tensor, paddle.Tensor], + tuple, + EventOverlap, + Callable, + ]: + assert num_experts % self.e_num_ranks == 0 + m2n_topk_idx = topk_idx + m2n_num_experts = (num_experts // self.e_num_ranks) * ( + self.a_num_ranks + self.e_num_ranks + ) + + ( + packed_recv_x, + _, + _, + handle, + event, + hook, + ) = self.all2all_buffer.m2n_low_latency_dispatch_two_stage( + x, + m2n_topk_idx, + topk_weights, + pre_allocated_result_memory, + num_max_dispatch_tokens_per_rank, + m2n_num_experts, + self.a_start_rank, + self.a_num_ranks, + self.e_start_rank, + self.e_num_ranks, + use_fp8=use_fp8, + async_finish=True, + return_recv_hook=True, + ) + + return ( + packed_recv_x, + handle, + event, + hook, + ) + + def a2e_irecv_two_stage_v3( + self, + pre_allocated_result_memory, + hidden: int, + num_topk: int, + num_max_dispatch_tokens_per_rank: int, + num_experts: int, + use_fp8: bool = True, + ) -> tuple[ + tuple[paddle.Tensor, paddle.Tensor], + paddle.Tensor, + tuple, + EventOverlap, + Callable, + ]: + x = paddle.empty((0, hidden), dtype="bfloat16") + + topk_idx = paddle.empty( + (0, num_topk), + dtype='int64', + ) + + topk_weights = paddle.empty( + (0, num_topk), + dtype="float32", + ) + + assert num_experts % self.e_num_ranks == 0 + m2n_num_experts = (num_experts // self.e_num_ranks) * ( + self.a_num_ranks + self.e_num_ranks + ) + + ( + packed_recv_x, + packed_recv_count, + rdma_send_flags, + handle, + event, + hook, + ) = self.all2all_buffer.m2n_low_latency_dispatch_two_stage( + x, + topk_idx, + topk_weights, + pre_allocated_result_memory, + num_max_dispatch_tokens_per_rank, + m2n_num_experts, + self.a_start_rank, + self.a_num_ranks, + self.e_start_rank, + self.e_num_ranks, + use_fp8=use_fp8, + async_finish=True, + return_recv_hook=True, + ) + + return ( + packed_recv_x, + packed_recv_count, + rdma_send_flags, + handle, + event, + hook, + ) + + def e2a_isend_two_stage_v3( + self, + x: paddle.Tensor, + num_topk: int, + handle: tuple, + dispatch_use_fp8: bool = False, + out: paddle.Tensor | None = None, + ) -> tuple[EventOverlap, Callable]: + topk_idx = paddle.empty( + (0, num_topk), + dtype='int64', + ) + + topk_weights = paddle.empty( + (0, num_topk), + dtype="float32", + ) + + _, event, hook = self.all2all_buffer.m2n_low_latency_combine_two_stage( + x, + topk_idx, + topk_weights, + handle, + self.a_start_rank, + self.a_num_ranks, + self.e_start_rank, + self.e_num_ranks, + async_finish=True, + dispatch_use_fp8=dispatch_use_fp8, + return_recv_hook=True, + out=out, + ) + + return ( + event, + hook, + ) + + def e2a_irecv_two_stage_v3( + self, + topk_idx: paddle.Tensor, + topk_weights: paddle.Tensor, + handle: tuple, + dispatch_use_fp8: bool = False, + out: paddle.Tensor | None = None, + ) -> tuple[paddle.Tensor, EventOverlap, Callable]: + ( + 
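+            # Unpack the communication handle produced by the dispatch stage.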
packed_recv_rdma_x, + src_info, + layout_range, + rdma_send_flags, + packed_rdma_recv_count, + num_max_dispatch_tokens_per_rank, + hidden, + m2n_num_experts, + ) = handle + m2n_num_ranks = self.a_num_ranks + self.e_num_ranks + m2n_topk_idx = topk_idx + # TODO: only pass the check, this is not needed + x = paddle.empty( + ( + m2n_num_experts // m2n_num_ranks, + m2n_num_ranks * num_max_dispatch_tokens_per_rank, + hidden, + ), + dtype="bfloat16", + ) + combined_x, event, hook = ( + self.all2all_buffer.m2n_low_latency_combine_two_stage( + x, + m2n_topk_idx, + topk_weights, + handle, + self.a_start_rank, + self.a_num_ranks, + self.e_start_rank, + self.e_num_ranks, + async_finish=True, + dispatch_use_fp8=dispatch_use_fp8, + return_recv_hook=True, + out=out, + ) + ) + + return ( + combined_x, + event, + hook, + ) diff --git a/test/collective/test_m2n.py b/test/collective/test_m2n.py new file mode 100644 index 0000000000000..2c85f902d2046 --- /dev/null +++ b/test/collective/test_m2n.py @@ -0,0 +1,528 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import contextlib +import random + +import numpy as np + +import paddle +import paddle.distributed as dist +from paddle.distributed import fleet +from paddle.distributed.communication import deep_ep + +num_max_tokens = 512 + + +def bench_split( + fn1, + fn2, + fn1_wait: bool = True, + fn2_wait: bool = True, + num_warmups: int = 50, + num_tests: int = 50, +): + # clear + cache = paddle.empty((int(256e6 // 4),), dtype="int32") + cache.zero_() + + # Warmup + for _ in range(num_warmups): + dist.barrier() + req = fn1() + if fn1_wait: + req.wait() + dist.barrier() + req = fn2() + if fn2_wait: + req.wait() + dist.barrier() + + # Flush L2 + cache.zero_() + del cache + + # Testing + start_events_fn1 = [ + paddle.device.Event(enable_timing=True) for _ in range(num_tests) + ] + end_events_fn1 = [ + paddle.device.Event(enable_timing=True) for _ in range(num_tests) + ] + start_events_fn2 = [ + paddle.device.Event(enable_timing=True) for _ in range(num_tests) + ] + end_events_fn2 = [ + paddle.device.Event(enable_timing=True) for _ in range(num_tests) + ] + for i in range(num_tests): + # Record + dist.barrier() + start_events_fn1[i].record() + req = fn1() + end_events_fn1[i].record() + if fn1_wait: + req.wait() + dist.barrier() + start_events_fn2[i].record() + req = fn2() + end_events_fn2[i].record() + if fn2_wait: + req.wait() + dist.barrier() + paddle.device.synchronize() + + times_fn1 = np.array( + [ + s.elapsed_time(e) / 1e3 + for s, e in zip(start_events_fn1, end_events_fn1) + ] + )[1:] + times_fn2 = np.array( + [ + s.elapsed_time(e) / 1e3 + for s, e in zip(start_events_fn2, end_events_fn2) + ] + )[1:] + return ( + np.average(times_fn1), + np.min(times_fn1), + np.max(times_fn1), + np.average(times_fn2), + np.min(times_fn2), + np.max(times_fn2), + ) + + +def bench_m2n(fn, num_warmups: int = 50, num_tests: int = 50): + # clear + cache = paddle.empty((int(256e6 // 4),), dtype="int32") + cache.zero_() 
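+    # ~256 MB scratch buffer; zeroing it flushes the L2 cache between runs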
+ + # Warmup + for _ in range(num_warmups): + dist.barrier() + fn() + dist.barrier() + + # Flush L2 + cache.zero_() + del cache + + # Testing + start_events_fn = [ + paddle.device.Event(enable_timing=True) for _ in range(num_tests) + ] + end_events_fn = [ + paddle.device.Event(enable_timing=True) for _ in range(num_tests) + ] + for i in range(num_tests): + dist.barrier() + start_events_fn[i].record() + fn() + end_events_fn[i].record() + dist.barrier() + paddle.device.synchronize() + + times_fn = np.array( + [ + s.elapsed_time(e) / 1e3 + for s, e in zip(start_events_fn, end_events_fn) + ] + )[1:] + return ( + np.average(times_fn), + np.min(times_fn), + np.max(times_fn), + ) + + +def per_token_cast_back(x_fp8: paddle.Tensor, x_scales: paddle.Tensor): + x_fp32 = x_fp8.to("float32").view((x_fp8.shape[0], -1, 128)) + x_scales = x_scales.view((x_fp8.shape[0], -1, 1)) + return (x_fp32 * x_scales).view(x_fp8.shape).to("bfloat16") + + +def test_main( + num_tokens: int, + hidden: int, + num_experts: int, + num_topk: int, + use_fp8: bool, + rank: int, + num_ranks: int, + a_start_rank: int, + a_num_ranks: int, + e_start_rank: int, + e_num_ranks: int, + group: dist.communication.group, + buffer: deep_ep.Buffer, + seed: int = 0, +): + paddle.seed(seed + rank) + random.seed(seed + rank) + + assert num_experts % e_num_ranks == 0 + num_local_experts = num_experts // e_num_ranks + num_rdma_ranks = num_ranks / 8 + + # NOTES: the integers greater than 256 exceeds the BF16 precision limit + rank_offset = 128 + assert num_ranks - rank_offset < 257, ( + 'Too many ranks (exceeding test precision limit)' + ) + + x = paddle.ones((num_tokens, hidden), dtype="bfloat16") * ( + rank - rank_offset + ) + # x[:, -128:] = paddle.arange(0, num_tokens, dtype="bfloat16").view((-1, 1)) + # x = paddle.randn((num_tokens, hidden), dtype="bfloat16") + # x = paddle.ones((num_tokens, hidden), dtype="bfloat16") * 3 + topk_idx = paddle.randint( + 0, num_experts, shape=[num_tokens, num_topk], dtype="int64" + ) + print(f"rank: {rank}, num_local_experts: {num_local_experts}") + topk_weights = paddle.randn((num_tokens, num_topk), dtype="float32").abs_() + # topk_weights = paddle.ones((num_tokens, num_topk), dtype="float32") * 5 + print("x: ", x, flush=True) + print("topk_idx: ", topk_idx, flush=True) + print("topk_weights: ", topk_weights, flush=True) + + # Calculate bandwidth + num_fp8_bytes, num_bf16_bytes = (hidden + hidden / 128 * 4 + 16), hidden * 2 + num_dispatch_comm_bytes, num_combine_comm_bytes = 0, 0 + for i in range(num_tokens): + num_selections = (topk_idx[i] != -1).sum().item() + num_dispatch_comm_bytes += num_fp8_bytes * num_selections + num_combine_comm_bytes += num_bf16_bytes * num_selections + + paddle.device.synchronize() + dist.barrier() + run_time = 1 + print("run_time: ", run_time) + print("num_experts: ", num_experts) + + ref_recv_x = paddle.zeros( + (e_num_ranks, num_local_experts, hidden), dtype=paddle.float32 + ) # [8, 3, 128] + gbl_recv_x = paddle.zeros( + (e_num_ranks, num_local_experts, hidden), dtype=paddle.float32 + ) # [8, 3, 128] + ref_combin_x = paddle.zeros( + (num_tokens, hidden), dtype=paddle.float32 + ) # [96, 8192] + gbl_combin_x = paddle.zeros( + (num_tokens, hidden), dtype=paddle.float32 + ) # [96, 8192] + + if rank >= a_start_rank and rank < a_start_rank + a_num_ranks: + if not use_fp8: + ref_recv_x.zero_() + gbl_recv_x.zero_() + ref_combin_x.zero_() + gbl_combin_x.zero_() + for i in range(num_tokens): + for k, expert_id in enumerate(topk_idx[i]): + if expert_id == -1: + continue + erank_id = 
expert_id // num_local_experts # 0-7 + local_expert_id = expert_id % num_local_experts # 0-2 + ref_recv_x[erank_id, local_expert_id] += x[i].to( + paddle.float32 + ) + ref_combin_x[i] += ( + x[i].to(paddle.float32) * topk_weights[i][k] + ) + + packed_recv_x, handle, event, req = buffer.a2e_isend_two_stage( + x, + topk_idx, + topk_weights, + num_max_tokens, + num_experts, + use_fp8=use_fp8, + ) + + req.wait() + dist.barrier() + + e2a_x, event, req = buffer.e2a_irecv_two_stage( + topk_idx, + topk_weights, + handle, + dispatch_use_fp8=use_fp8, + out=None, + ) + + req.wait() + dist.barrier() + + gbl_combin_x = e2a_x.to(paddle.float32) + + def a2e_isend_func(): + packed_recv_x, handle, event, req = buffer.a2e_isend_two_stage( + x, + topk_idx, + topk_weights, + num_max_tokens, + num_experts, + use_fp8=use_fp8, + ) + return req + + def e2a_irecv_func(): + e2a_x, event, req = buffer.e2a_irecv_two_stage( + topk_idx, + topk_weights, + handle, + dispatch_use_fp8=use_fp8, + out=None, + ) + req.wait() + return req + + avg_t_fn1, min_t_fn1, max_t_fn1, avg_t_fn2, min_t_fn2, max_t_fn2 = ( + bench_split( + a2e_isend_func, e2a_irecv_func, fn1_wait=True, fn2_wait=False + ) + ) + print( + f'[rank: {rank}][a2e_isend_two_stage] ' + f'avg_t: {avg_t_fn1 * 1e6:.2f} us, min_t: {min_t_fn1 * 1e6:.2f} us, max_t: {max_t_fn1 * 1e6:.2f} us', + flush=True, + ) + print( + f'[rank: {rank}][e2a_irecv_two_stage] ' + f'avg_t: {avg_t_fn2 * 1e6:.2f} us, min_t: {min_t_fn2 * 1e6:.2f} us, max_t: {max_t_fn2 * 1e6:.2f} us', + flush=True, + ) + + if rank >= e_start_rank and rank < e_start_rank + e_num_ranks: + ( + packed_recv_x, + packed_recv_count, + rdma_send_flags, + handle, + event, + req, + ) = buffer.a2e_irecv_two_stage( + hidden, + num_topk, + num_max_tokens, + num_experts, + use_fp8=use_fp8, + ) + req.wait() + print( + f'[rank: {rank}, packed_recv_count: {packed_recv_count}], packed_recv_x[1]: {packed_recv_x[1]}', + flush=True, + ) + dist.barrier() + + if not use_fp8: + for local_expert_id in range(num_local_experts): + gbl_recv_x[rank - e_start_rank, local_expert_id] = ( + packed_recv_x[ + local_expert_id, : packed_recv_count[local_expert_id] + ] + .to(paddle.float32) + .sum(0) + ) + + # e2a isend + if use_fp8: + simulated_gemm_x = per_token_cast_back( + packed_recv_x[0].view((-1, hidden)), + packed_recv_x[1].contiguous().view((-1, hidden // 128)), + ).view(packed_recv_x[0].shape) + else: + simulated_gemm_x = packed_recv_x.clone() + + event, req = buffer.e2a_isend_two_stage( + simulated_gemm_x, + num_topk, + handle, + dispatch_use_fp8=use_fp8, + out=None, + ) + + req.wait() + dist.barrier() + + def a2e_irecv_func(): + ( + packed_recv_x, + packed_recv_count, + rdma_send_flags, + handle, + event, + req, + ) = buffer.a2e_irecv_two_stage( + hidden, + num_topk, + num_max_tokens, + num_experts, + use_fp8=use_fp8, + ) + # event.current_stream_wait() + req.wait() + return req + + def e2a_isend_func(): + event, req = buffer.e2a_isend_two_stage( + simulated_gemm_x, + num_topk, + handle, + dispatch_use_fp8=use_fp8, + out=None, + ) + return req + + avg_t_fn1, min_t_fn1, max_t_fn1, avg_t_fn2, min_t_fn2, max_t_fn2 = ( + bench_split( + a2e_irecv_func, e2a_isend_func, fn1_wait=False, fn2_wait=True + ) + ) + print( + f'[rank: {rank}][a2e_irecv_two_stage] ' + f'avg_t: {avg_t_fn1 * 1e6:.2f} us, min_t: {min_t_fn1 * 1e6:.2f} us, max_t: {max_t_fn1 * 1e6:.2f} us', + flush=True, + ) + print( + f'[rank: {rank}][e2a_isend_two_stage] ' + f'avg_t: {avg_t_fn2 * 1e6:.2f} us, min_t: {min_t_fn2 * 1e6:.2f} us, max_t: {max_t_fn2 * 1e6:.2f} us', + flush=True, 
+ ) + + if not use_fp8: + dist.all_reduce(ref_recv_x, group=group) + dist.all_reduce(gbl_recv_x, group=group) + assert paddle.allclose(ref_recv_x, gbl_recv_x, rtol=1e-3, atol=1e-3), ( + f"[rank: {rank}], ref_recv_x: {ref_recv_x}, gbl_recv_x: {gbl_recv_x}" + ) + print( + f"[rank: {rank}], ref_recv_x: {ref_recv_x}, gbl_recv_x: {gbl_recv_x}" + ) + assert paddle.allclose( + ref_combin_x, gbl_combin_x, rtol=1.0, atol=1.0 + ), ( + f"[rank: {rank}], ref_combin_x: {ref_combin_x}, gbl_combin_x: {gbl_combin_x}" + ) + print( + f"[rank: {rank}], ref_combin_x: {ref_combin_x}, gbl_combin_x: {gbl_combin_x}" + ) + print(f"rank: {rank} passed the check") + dist.barrier() + + +def test_loop(): + rank = dist.get_rank() + num_ranks = dist.get_world_size() + group = paddle.distributed.new_group(range(num_ranks)) + print("rank: ", rank, flush=True) + print("num_ranks: ", num_ranks, flush=True) + + a_start_rank = 0 + a_num_ranks = 16 + e_start_rank = a_start_rank + a_num_ranks + e_num_ranks = num_ranks - a_num_ranks + # 64 * 3 / 48 = 4 + # 64 * 3 / 32 = 6 + # 64 * 3 / 24 = 8 + # 64 * 3 / 12 = 16 + num_tokens, hidden, num_topk, num_experts = 96, 8192, 8, 64 + + assert num_tokens <= num_max_tokens, ( + "num_tokens must be less equal to num_max_tokens" + ) + num_rdma_ranks = num_ranks / 8 + num_local_experts = num_experts / num_ranks + num_rdma_bytes = deep_ep.M2NBuffer.get_low_latency_rdma_size_hint_two_stage( + num_max_tokens, + hidden, + num_ranks, + a_num_ranks, + e_num_ranks, + num_experts, + num_topk, + ) + + use_fp8 = True + num_nvl_bytes = deep_ep.M2NBuffer.get_low_latency_nvl_size_hint_two_stage( + num_max_tokens, + hidden, + num_ranks, + a_num_ranks, + e_num_ranks, + num_experts, + num_topk, + use_fp8, + ) + print( + f'Allocating rdma buffer size: {num_rdma_bytes / 1e6} MB, nvl buffer size: {num_nvl_bytes / 1e6} MB...', + flush=True, + ) + + buffer = deep_ep.M2NBuffer( + group, + a_start_rank, + a_num_ranks, + e_start_rank, + e_num_ranks, + num_nvl_bytes=num_nvl_bytes, + num_rdma_bytes=num_rdma_bytes, + low_latency_mode=True, + num_qps_per_rank=num_rdma_ranks, + ) + test_main( + num_tokens, + hidden, + num_experts, + num_topk, + use_fp8, + rank, + num_ranks, + a_start_rank, + a_num_ranks, + e_start_rank, + e_num_ranks, + group, + buffer, + seed=1, + ) + + +def init_dist_env(world_size, seed=20): + context = contextlib.nullcontext() + with context: + # start to init distributed env + strategy = fleet.DistributedStrategy() + + strategy.hybrid_configs = { + "dp_degree": 1, + "mp_degree": world_size, + "pp_degree": 1, + "sharding_degree": 1, + } + + # Set control in tensor parallel + strategy.tensor_parallel_configs = {"tensor_init_seed": seed} + + fleet.init(is_collective=True, strategy=strategy) + + +if __name__ == '__main__': + if dist.get_world_size() > 1: + init_dist_env(dist.get_world_size()) + test_loop() diff --git a/test/collective/test_m2n_all_layers_v3.py b/test/collective/test_m2n_all_layers_v3.py new file mode 100644 index 0000000000000..b11f3da53ffbe --- /dev/null +++ b/test/collective/test_m2n_all_layers_v3.py @@ -0,0 +1,562 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import contextlib +import random +import time + +import paddle +import paddle.distributed as dist +from paddle import Tensor +from paddle.distributed import fleet +from paddle.distributed.communication import deep_ep +from paddle.incubate.fp8 import deep_gemm +from paddle.incubate.fp8.deep_gemm import ( + ceil_div, + get_col_major_tma_aligned_tensor, +) + +num_max_tokens = 512 + +M2N_DEBUG = False +M2N_ACC_DEBUG = False +M2N_DEVICE_SYNC = False + + +def per_token_cast_to_fp8(x: Tensor) -> tuple[Tensor, Tensor]: + assert x.dim() == 2 and x.shape[1] % 128 == 0 + m, n = x.shape + x_view = paddle.view(x, (m, -1, 128)) + x_abs = paddle.abs(x_view).astype(paddle.float32) + x_amax = paddle.amax(x_abs, axis=2) + x_amax = paddle.view(x_amax, (m, -1)) + x_amax = paddle.clip(x_amax, min=1e-4) + scaled_x = x_view * (448.0 / x_amax.unsqueeze(2)) + scaled_x_converted = paddle.view( + scaled_x.astype(paddle.float8_e4m3fn), (m, n) + ) + + x_amax_scaled = paddle.view((x_amax / 448.0), (m, -1)) + + result = (scaled_x_converted, x_amax_scaled) + return result + + +def per_block_cast_to_fp8(x: Tensor) -> tuple[Tensor, Tensor]: + assert x.dim() == 2 + m, n = x.shape + x_padded = paddle.zeros( + (ceil_div(m, 128) * 128, ceil_div(n, 128) * 128), dtype=x.dtype + ) + x_padded[:m, :n] = x + x_view = paddle.view(x_padded, (-1, 128, x_padded.shape[1] // 128, 128)) + + x_abs = paddle.abs(x_view).astype(paddle.float32) + x_amax = paddle.amax(x_abs, axis=(1, 3), keepdim=True) + x_amax = paddle.clip(x_amax, min=1e-4) + x_scaled = (x_view * (448.0 / x_amax)).astype(paddle.float8_e4m3fn) + + return x_scaled.view_as(x_padded)[:m, :n].contiguous(), ( + paddle.view(x_amax / 448.0, (x_view.shape[0], x_view.shape[2])) + ) + + +def construct( + x: Tensor, y: Tensor +) -> tuple[tuple[Tensor, Tensor], tuple[Tensor, Tensor], Tensor, Tensor]: + x_fp8, y_fp8 = per_token_cast_to_fp8(x), per_block_cast_to_fp8(y) + # Transpose earlier so that the testing will not trigger transposing kernels + x_fp8 = (x_fp8[0], get_col_major_tma_aligned_tensor(x_fp8[1])) + return x_fp8, y_fp8 + + +def per_token_cast_back(x_fp8: paddle.Tensor, x_scales: paddle.Tensor): + x_fp32 = x_fp8.to("float32").view((x_fp8.shape[0], -1, 128)) + x_scales = x_scales.view((x_fp8.shape[0], -1, 1)) + return (x_fp32 * x_scales).view(x_fp8.shape).to("bfloat16") + + +A = paddle.randn((96, 7168), dtype="bfloat16") +B = paddle.randn((7168, 7168), dtype="bfloat16") +C = paddle.randn((96, 7168), dtype="bfloat16") + +A_fp8, B_fp8 = construct(A, B) + + +def moe(x: Tensor, y: Tensor): + [paddle.matmul(x, y) for _ in range(9)] + return paddle.matmul(x, y) + + +def moe_fp8(x_fp8: Tensor, y_fp8: Tensor, out: Tensor): + deep_gemm.gemm_fp8_fp8_bf16_nt(x_fp8, y_fp8, out, num_sms=108) + [ + deep_gemm.gemm_fp8_fp8_bf16_nt(x_fp8, y_fp8, out, num_sms=108) + for i in range(9) + ] + + +def attention(x: Tensor, y: Tensor): + return moe(x, y) + + +def attention_fp8(x_fp8: Tensor, y_fp8: Tensor, out: Tensor): + deep_gemm.gemm_fp8_fp8_bf16_nt(x_fp8, y_fp8, out, num_sms=108) + [ + deep_gemm.gemm_fp8_fp8_bf16_nt(x_fp8, y_fp8, out, num_sms=108) + for i in range(9) + ] + + +def test_main( 
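+    # Ranks [a_start_rank, a_start_rank + a_num_ranks) run attention;
+    # ranks [e_start_rank, e_start_rank + e_num_ranks) host the experts.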
+ num_tokens: int, + hidden: int, + num_experts: int, + num_topk: int, + use_fp8: bool, + rank: int, + num_ranks: int, + a_start_rank: int, + a_num_ranks: int, + e_start_rank: int, + e_num_ranks: int, + group: dist.communication.group, + buffer: deep_ep.Buffer, + seed: int = 0, +): + paddle.seed(seed + rank) + random.seed(seed + rank) + + assert num_experts % e_num_ranks == 0 + num_local_experts = num_experts // e_num_ranks + + # NOTES: the integers greater than 256 exceeds the BF16 precision limit + rank_offset = 128 + assert num_ranks - rank_offset < 257, ( + 'Too many ranks (exceeding test precision limit)' + ) + + intermediate_size = hidden # 28672 + num_micro_batches = 3 + GB = num_tokens * 3 + MB = num_tokens + num_hidden_layers = 51 + moe_layer_start_index = 0 + num_benches = -1 + + # x_fp8, y_fp8 = construct(x, y) + # m, k = x.shape + # n, k = y.shape + # out = paddle.empty((m, n), dtype=paddle.bfloat16) + + # 整体思路 + # 1. 单层循环 + # 2. 以计算index为基准,通信index进行相应的偏移 + # 3. a2e 计算放到循环的开始位置, 最后一个micro batch循环不到, 放到循环结束单独处理 + # 4. e2a 计算放到循环的结束位置, 第一micro batch循环不到,放到循环开始之前单独处理 + # 5. 只在通信index有效的位置进行通信操作 + if rank >= a_start_rank and rank < a_start_rank + a_num_ranks: + # x = + xs = [ + paddle.ones((num_tokens, hidden), dtype="bfloat16") * (i + 2) + for i in range(num_micro_batches) + ] + weights = paddle.eye(intermediate_size, hidden, dtype="bfloat16") + + topk_idx = paddle.randint( + 0, num_experts, shape=[num_tokens, num_topk], dtype="int64" + ) + print(f"rank: {rank}, num_local_experts: {num_local_experts}") + topk_weights = paddle.ones( + (num_tokens, num_topk), dtype="float32" + ).abs_() # / num_topk + + a2e_send_result = [None] * num_micro_batches + e2a_recv_result = [None] * num_micro_batches + # for i in range(num_benches): + i = -1 + while True: + paddle.device.synchronize() + dist.barrier() + i += 1 + if num_benches > 0 and i >= num_benches: + break + # x = paddle.ones((num_tokens, hidden), dtype="bfloat16") * ( + # rank + 1 + # ) + # loop + for idx in range( + moe_layer_start_index * num_micro_batches, + num_hidden_layers * num_micro_batches, + ): + a2e_layer_idx = idx // num_micro_batches # idx + a2e_mb_idx = idx % num_micro_batches # idx + + e2a_layer_idx_next = ( + idx - num_micro_batches + 2 + ) // num_micro_batches # idx - 2 + e2a_mb_idx_next = ( + idx - num_micro_batches + 2 + ) % num_micro_batches # idx - 2 + # attention + # x = attention(x, weights) # 96 28672 + xs[a2e_mb_idx] = attention(xs[a2e_mb_idx], weights) + if M2N_ACC_DEBUG: + print( + f"====== {i} compute attention {a2e_mb_idx}_{a2e_layer_idx}: {xs[a2e_mb_idx]}", + flush=True, + ) + + if M2N_DEBUG: + print( + f"====== {i} compute attention {a2e_mb_idx}_{a2e_layer_idx}: {xs[a2e_mb_idx]}", + flush=True, + ) + + # # attn 等待上一个micro batch数据接收完 + # if a2e_layer_idx_pre >= moe_layer_start_index: + # _, _, event, hook = a2e_send_result[a2e_mb_idx_pre] + # # event.current_stream_wait() + # hook() # .current_stream_wait() + # if M2N_DEVICE_SYNC: + # paddle.device.synchronize() + # if M2N_DEBUG: + # print(f"{i} dispatch send wait attention {a2e_mb_idx_pre}_{a2e_layer_idx_pre} data end", flush=True) + + # attn 每一个micro batch均发送数据 + a2e_send_result[a2e_mb_idx] = buffer.a2e_isend_two_stage_v3( + xs[a2e_mb_idx], + topk_idx, + topk_weights, + num_max_tokens, + num_experts, + use_fp8=use_fp8, + ) + if M2N_DEVICE_SYNC: + paddle.device.synchronize() + if M2N_DEBUG: + print( + f"{i} dispatch send attention {a2e_mb_idx}_{a2e_layer_idx} data begin", + flush=True, + ) + + _, _, event, hook = a2e_send_result[a2e_mb_idx] + # 
event.current_stream_wait() + hook() # .current_stream_wait() + if M2N_DEVICE_SYNC: + paddle.device.synchronize() + if M2N_DEBUG: + print( + f"{i} dispatch send wait attention {a2e_mb_idx}_{a2e_layer_idx} data end", + flush=True, + ) + + # attn 最后一层不在接收数据 + if ( + e2a_layer_idx_next >= moe_layer_start_index + and e2a_layer_idx_next < num_hidden_layers - 1 + ): + _, handle, _, _ = a2e_send_result[e2a_mb_idx_next] + e2a_recv_result[e2a_mb_idx_next] = ( + buffer.e2a_irecv_two_stage_v3( + topk_idx, + topk_weights, + handle, + dispatch_use_fp8=use_fp8, + out=None, + ) + ) + if M2N_DEVICE_SYNC: + paddle.device.synchronize() + if M2N_DEBUG: + print( + f"{i} combine recv moe {e2a_mb_idx_next}_{e2a_layer_idx_next} data begin", + flush=True, + ) + + e2a_x, event, hook = e2a_recv_result[e2a_mb_idx_next] + # event.current_stream_wait() + hook() # .current_stream_wait() + # x = e2a_x + # print(f"{i} combine recv wait moe {e2a_mb_idx}_{e2a_layer_idx} data end, x: {x}", flush=True) + xs[e2a_mb_idx_next] = e2a_x + + if M2N_DEVICE_SYNC: + paddle.device.synchronize() + if M2N_DEBUG: + print( + f"{i} combine recv wait moe {e2a_mb_idx_next}_{e2a_layer_idx_next} data end", + flush=True, + ) + + print(f"==================== {i}", flush=True) + # time.sleep(1) + + if rank >= e_start_rank and rank < e_start_rank + e_num_ranks: + weights = paddle.eye(intermediate_size, hidden, dtype="bfloat16") + a2e_recv_result = [None] * num_micro_batches + e2a_send_result = [None] * num_micro_batches + i = -1 + # for i in range(num_benches): + while True: + paddle.device.synchronize() + dist.barrier() + i += 1 + if num_benches > 0 and i >= num_benches: + break + # loop + a2e_recv_result[0] = buffer.a2e_irecv_two_stage_v3( + hidden, + num_topk, + num_max_tokens, + num_experts, + use_fp8=use_fp8, + ) + if M2N_DEVICE_SYNC: + paddle.device.synchronize() + if M2N_DEBUG: + print( + f"0 dispatch recv attention {0}_{0} data begin", flush=True + ) + + # moe 每一个micro batch 都等待数据接收完 + _, _, _, _, _, hook = a2e_recv_result[0] + # event.current_stream_wait() + hook().current_stream_wait() + + if M2N_DEVICE_SYNC: + paddle.device.synchronize() + if M2N_DEBUG: + print(f"0 dispatch recv tion {0}_{0} data end", flush=True) + + for idx in range( + moe_layer_start_index * num_micro_batches, + num_hidden_layers * num_micro_batches, + ): + a2e_layer_idx = idx // num_micro_batches + a2e_mb_idx = idx % num_micro_batches + a2e_layer_idx_next = (idx + 1) // num_micro_batches + a2e_mb_idx_next = (idx + 1) % num_micro_batches + + e2a_layer_idx = idx // num_micro_batches + e2a_mb_idx = idx % num_micro_batches + + if idx < num_hidden_layers * num_micro_batches - 1: + a2e_recv_result[a2e_mb_idx_next] = ( + buffer.a2e_irecv_two_stage_v3( + hidden, + num_topk, + num_max_tokens, + num_experts, + use_fp8=use_fp8, + ) + ) + if M2N_DEVICE_SYNC: + paddle.device.synchronize() + if M2N_DEBUG: + print( + f"{i} dispatch recv attention {a2e_mb_idx_next}_{a2e_layer_idx_next} data begin", + flush=True, + ) + + # moe 每一个micro batch 都等待数据接收完 + _, _, _, _, _, hook = a2e_recv_result[a2e_mb_idx_next] + # event.current_stream_wait() + hook() # .current_stream_wait() + + # if use_fp8: + # simulated_gemm_x = per_token_cast_back( + # packed_recv_x[0].view((-1, hidden)), + # packed_recv_x[1].contiguous().view((-1, hidden // 128)), + # ).view(packed_recv_x[0].shape) + # else: + # simulated_gemm_x = packed_recv_x.clone() + + # paddle.device.synchronize() + # print(f"dispatch recv wait attention {a2e_mb_idx}_{a2e_layer_idx} data end, packed_recv_x: {packed_recv_x}", flush=True) + if 
+def test_loop():
+    rank = dist.get_rank()
+    num_ranks = dist.get_world_size()
+    group = paddle.distributed.new_group(range(num_ranks))
+    print("rank: ", rank, flush=True)
+    print("num_ranks: ", num_ranks, flush=True)
+
+    a_start_rank = 0
+    a_num_ranks = 8
+    e_start_rank = a_start_rank + a_num_ranks
+    e_num_ranks = num_ranks - a_num_ranks
+
+    num_tokens, hidden, num_topk, num_experts = 96, 7168, 8, 64
+
+    assert num_tokens <= num_max_tokens, (
+        "num_tokens must be less than or equal to num_max_tokens"
+    )
+    # use integer division: num_qps_per_rank below expects an int
+    num_rdma_ranks = num_ranks // 8
+    num_local_experts = num_experts // num_ranks
+    num_rdma_bytes = deep_ep.M2NBuffer.get_low_latency_rdma_size_hint_two_stage(
+        num_max_tokens,
+        hidden,
+        num_ranks,
+        a_num_ranks,
+        e_num_ranks,
+        num_experts,
+        num_topk,
+    )
+
+    use_fp8 = False
+    num_nvl_bytes = deep_ep.M2NBuffer.get_low_latency_nvl_size_hint_two_stage(
+        num_max_tokens,
+        hidden,
+        num_ranks,
+        a_num_ranks,
+        e_num_ranks,
+        num_experts,
+        num_topk,
+        use_fp8,
+    )
+    print(
+        f'Allocating rdma buffer size: {num_rdma_bytes / 1e6} MB, nvl buffer size: {num_nvl_bytes / 1e6} MB...',
+        flush=True,
+    )
+
+    buffer = deep_ep.M2NBuffer(
+        group,
+        a_start_rank,
+        a_num_ranks,
+        e_start_rank,
+        e_num_ranks,
+        num_nvl_bytes=num_nvl_bytes,
+        num_rdma_bytes=num_rdma_bytes,
+        low_latency_mode=True,
+        num_qps_per_rank=num_rdma_ranks,
+    )
+    test_main(
+        num_tokens,
+        hidden,
+        num_experts,
+        num_topk,
+        use_fp8,
+        rank,
+        num_ranks,
+        a_start_rank,
+        a_num_ranks,
+        e_start_rank,
+        e_num_ranks,
+        group,
+        buffer,
+        seed=1,
+    )
+
+
+def init_dist_env(world_size, seed=20):
+    context = contextlib.nullcontext()
+    with context:
+        # start to init distributed env
+        strategy = fleet.DistributedStrategy()
+
+        strategy.hybrid_configs = {
+            "dp_degree": 1,
+            "mp_degree": world_size,
+            "pp_degree": 1,
+            "sharding_degree": 1,
+        }
+
+        # Set control in tensor parallel
+        strategy.tensor_parallel_configs = {"tensor_init_seed": seed}
+
+        fleet.init(is_collective=True, strategy=strategy)
+
+
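+# NOTE: assumed launch command for the default split above (8 attention ranks
+# + 8 expert ranks = 16 GPUs); the script filename is illustrative:
+#   python -m paddle.distributed.launch --gpus="0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15" test_m2n.py
+# Multi-node runs would pass --ips instead of listing all GPUs on one host.
+
+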
"sharding_degree": 1, + } + + # Set control in tensor parallel + strategy.tensor_parallel_configs = {"tensor_init_seed": seed} + + fleet.init(is_collective=True, strategy=strategy) + + +if __name__ == '__main__': + if dist.get_world_size() > 1: + init_dist_env(dist.get_world_size()) + test_loop()