diff --git a/.gitignore b/.gitignore index cdbd9555..c9cdd19c 100644 --- a/.gitignore +++ b/.gitignore @@ -7,3 +7,6 @@ build/ .vscode/ */cmake-build-*/ *.egg-info +draft.md +docs/plan.md +.humanize/ \ No newline at end of file diff --git a/csrc/deep_ep.cpp b/csrc/deep_ep.cpp index bd8f23eb..a77078fb 100644 --- a/csrc/deep_ep.cpp +++ b/csrc/deep_ep.cpp @@ -151,7 +151,12 @@ Buffer::Buffer(int rank, comm_ctx = reinterpret_cast(pg) ->GetOrCreateCommContext(place, phi::distributed::CommType::ALLTOALL); - comm_stream = comm_ctx->GetStream(); + // Construct at::cuda::CUDAStream from raw cudaStream_t + cudaStream_t raw_stream = comm_ctx->GetStream(); + c10::StreamId sid = static_cast(reinterpret_cast(raw_stream)); + comm_stream.emplace(c10::Stream(c10::Stream::UNSAFE, + c10::Device(c10::DeviceType::CUDA, device_id), + sid)); calc_ctx = reinterpret_cast( reinterpret_cast(pg) ->GetDeviceContext(place, true)); @@ -204,12 +209,12 @@ Buffer::Buffer(int rank, reinterpret_cast(static_cast(buffer_ptrs[nvl_rank]) + num_nvl_bytes + barrier_signal_bytes + buffer_ptr_bytes); // No need to synchronize, will do a full device sync during `sync` - CUDA_CHECK(cudaMemsetAsync(barrier_signal_ptrs[nvl_rank], 0, barrier_signal_bytes, comm_stream)); + CUDA_CHECK(cudaMemsetAsync(barrier_signal_ptrs[nvl_rank], 0, barrier_signal_bytes, comm_stream.value().stream())); } // Create 32 MiB workspace CUDA_CHECK(cudaMalloc(&workspace, NUM_WORKSPACE_BYTES)); - CUDA_CHECK(cudaMemsetAsync(workspace, 0, NUM_WORKSPACE_BYTES, comm_stream)); + CUDA_CHECK(cudaMemsetAsync(workspace, 0, NUM_WORKSPACE_BYTES, comm_stream.value().stream())); // MoE counter CUDA_CHECK(cudaMallocHost(&moe_recv_counter, sizeof(int64_t), cudaHostAllocMapped)); @@ -286,9 +291,7 @@ torch::Tensor Buffer::get_local_buffer_tensor(const pybind11::object& dtype, int return torch::from_blob(base_ptr, num_bytes / element_bytes, torch::TensorOptions().dtype(casted_dtype).device(at::kCUDA)); } -cudaStream_t Buffer::get_comm_stream() const { - return comm_stream; -} + void Buffer::destroy() { EP_HOST_ASSERT(not destroyed); @@ -298,7 +301,7 @@ void Buffer::destroy() { if (num_nvl_bytes > 0) { // Barrier - intranode::barrier(barrier_signal_ptrs_gpu, nvl_rank, num_nvl_ranks, comm_stream); + intranode::barrier(barrier_signal_ptrs_gpu, nvl_rank, num_nvl_ranks, comm_stream.value().stream()); CUDA_CHECK(cudaDeviceSynchronize()); // Close remote IPC @@ -413,17 +416,17 @@ Buffer::get_dispatch_layout( // Allocate all tensors on comm stream if set // NOTES: do not allocate tensors upfront! - auto compute_stream = calc_ctx->stream(); + auto compute_stream = make_cuda_stream(calc_ctx->stream(), device_id); if (allocate_on_comm_stream) { EP_HOST_ASSERT(previous_event.has_value() and async); - deep_ep::SetAllocatorStreamForGPUContext(comm_stream, calc_ctx); + deep_ep::SetAllocatorStreamForGPUContext(comm_stream.value().stream(), calc_ctx); } // Wait previous tasks to be finished if (previous_event.has_value()) { - stream_wait(comm_stream, previous_event.value()); + stream_wait(comm_stream.value(), previous_event.value()); } else { - stream_wait(comm_stream, compute_stream); + stream_wait(comm_stream.value(), compute_stream); } auto num_tokens = static_cast(topk_idx.size(0)), num_topk = static_cast(topk_idx.size(1)); @@ -443,29 +446,29 @@ Buffer::get_dispatch_layout( num_topk, num_ranks, num_experts, - comm_stream); + comm_stream.value().stream()); // Wait streams std::optional event; if (async) { - event = EventHandle(comm_stream); + event = EventHandle(comm_stream.value()); for (auto& t : {topk_idx, num_tokens_per_rank, num_tokens_per_expert, is_token_in_rank}) { - t.record_stream(comm_stream); + t.record_stream(comm_stream.value()); if (allocate_on_comm_stream) t.record_stream(compute_stream); } for (auto& to : {num_tokens_per_rdma_rank}) { - to.has_value() ? to->record_stream(comm_stream) : void(); + to.has_value() ? to->record_stream(comm_stream.value()) : void(); if (allocate_on_comm_stream) to.has_value() ? to->record_stream(compute_stream) : void(); } } else { - stream_wait(compute_stream, comm_stream); + stream_wait(compute_stream, comm_stream.value()); } // Switch back compute stream if (allocate_on_comm_stream) - deep_ep::SetAllocatorStreamForGPUContext(compute_stream, calc_ctx); + deep_ep::SetAllocatorStreamForGPUContext(compute_stream.stream(), calc_ctx); return {num_tokens_per_rank, num_tokens_per_rdma_rank, num_tokens_per_expert, is_token_in_rank, event}; } @@ -575,17 +578,17 @@ Buffer::intranode_dispatch(const torch::Tensor& x, // Allocate all tensors on comm stream if set // NOTES: do not allocate tensors upfront! - auto compute_stream = calc_ctx->stream(); + auto compute_stream = make_cuda_stream(calc_ctx->stream(), device_id); if (allocate_on_comm_stream) { EP_HOST_ASSERT(previous_event.has_value() && async); - deep_ep::SetAllocatorStreamForGPUContext(comm_stream, calc_ctx); + deep_ep::SetAllocatorStreamForGPUContext(comm_stream.value().stream(), calc_ctx); } // Wait previous tasks to be finished if (previous_event.has_value()) { - stream_wait(comm_stream, previous_event.value()); + stream_wait(comm_stream.value(), previous_event.value()); } else { - stream_wait(comm_stream, compute_stream); + stream_wait(comm_stream.value(), compute_stream); } // Create handles (only return for non-cached mode) @@ -604,7 +607,7 @@ Buffer::intranode_dispatch(const torch::Tensor& x, // Copy rank prefix matrix and clean flags intranode::cached_notify_dispatch( - rank_prefix_matrix.data_ptr(), num_memset_int, buffer_ptrs_gpu, barrier_signal_ptrs_gpu, rank, num_ranks, comm_stream); + rank_prefix_matrix.data_ptr(), num_memset_int, buffer_ptrs_gpu, barrier_signal_ptrs_gpu, rank, num_ranks, comm_stream.value().stream()); } else { rank_prefix_matrix = torch::empty({num_ranks, num_ranks}, dtype(torch::kInt32).device(torch::kCUDA)); channel_prefix_matrix = torch::empty({num_ranks, num_channels}, dtype(torch::kInt32).device(torch::kCUDA)); @@ -633,7 +636,7 @@ Buffer::intranode_dispatch(const torch::Tensor& x, buffer_ptrs_gpu, barrier_signal_ptrs_gpu, rank, - comm_stream, + comm_stream.value().stream(), num_channels); if (num_worst_tokens > 0) { @@ -727,7 +730,7 @@ Buffer::intranode_dispatch(const torch::Tensor& x, buffer_ptrs_gpu, rank, num_ranks, - comm_stream, + comm_stream.value().stream(), config.num_sms, config.num_max_nvl_chunked_send_tokens, config.num_max_nvl_chunked_recv_tokens); @@ -735,10 +738,10 @@ Buffer::intranode_dispatch(const torch::Tensor& x, // Wait streams std::optional event; if (async) { - event = EventHandle(comm_stream); + event = EventHandle(comm_stream.value()); if (!skip_x_record_stream) { for (auto& t : {x, recv_x}) { - t.record_stream(comm_stream); + t.record_stream(comm_stream.value()); if (allocate_on_comm_stream) t.record_stream(compute_stream); } @@ -749,7 +752,7 @@ Buffer::intranode_dispatch(const torch::Tensor& x, recv_src_idx, recv_channel_prefix_matrix, send_head}) { - t.record_stream(comm_stream); + t.record_stream(comm_stream.value()); if (allocate_on_comm_stream) t.record_stream(compute_stream); } @@ -763,17 +766,17 @@ Buffer::intranode_dispatch(const torch::Tensor& x, recv_topk_idx, recv_topk_weights, recv_x_scales}) { - to.has_value() ? to->record_stream(comm_stream) : void(); + to.has_value() ? to->record_stream(comm_stream.value()) : void(); if (allocate_on_comm_stream) to.has_value() ? to->record_stream(compute_stream) : void(); } } else { - stream_wait(compute_stream, comm_stream); + stream_wait(compute_stream, comm_stream.value()); } // Switch back compute stream if (allocate_on_comm_stream) { - deep_ep::SetAllocatorStreamForGPUContext(compute_stream, calc_ctx); + deep_ep::SetAllocatorStreamForGPUContext(compute_stream.stream(), calc_ctx); } // Return values @@ -826,17 +829,17 @@ std::tuple, std::optionalstream(); + auto compute_stream = make_cuda_stream(calc_ctx->stream(), device_id); if (allocate_on_comm_stream) { EP_HOST_ASSERT(previous_event.has_value() && async); - deep_ep::SetAllocatorStreamForGPUContext(comm_stream, calc_ctx); + deep_ep::SetAllocatorStreamForGPUContext(comm_stream.value().stream(), calc_ctx); } // Wait previous tasks to be finished if (previous_event.has_value()) { - stream_wait(comm_stream, previous_event.value()); + stream_wait(comm_stream.value(), previous_event.value()); } else { - stream_wait(comm_stream, compute_stream); + stream_wait(comm_stream.value(), compute_stream); } int num_topk = 0; @@ -863,7 +866,7 @@ std::tuple, std::optional>({bias_0, bias_1}); @@ -902,37 +905,37 @@ std::tuple, std::optional event; if (async) { - event = EventHandle(comm_stream); + event = EventHandle(comm_stream.value()); if (!skip_x_record_stream) { - x.record_stream(comm_stream); + x.record_stream(comm_stream.value()); if (allocate_on_comm_stream) x.record_stream(compute_stream); } for (auto& t : {src_idx, send_head, rank_prefix_matrix, channel_prefix_matrix, recv_x}) { - t.record_stream(comm_stream); + t.record_stream(comm_stream.value()); if (allocate_on_comm_stream) t.record_stream(compute_stream); } for (auto& to : {topk_weights, recv_topk_weights, bias_0, bias_1}) { - to.has_value() ? to->record_stream(comm_stream) : void(); + to.has_value() ? to->record_stream(comm_stream.value()) : void(); if (allocate_on_comm_stream) to.has_value() ? to->record_stream(compute_stream) : void(); } } else { - stream_wait(compute_stream, comm_stream); + stream_wait(compute_stream, comm_stream.value()); } // Switch back compute stream if (allocate_on_comm_stream) { - deep_ep::SetAllocatorStreamForGPUContext(compute_stream, calc_ctx); + deep_ep::SetAllocatorStreamForGPUContext(compute_stream.stream(), calc_ctx); } return {recv_x, recv_topk_weights, event}; @@ -1068,17 +1071,17 @@ Buffer::internode_dispatch(const torch::Tensor& x, // Allocate all tensors on comm stream if set // NOTES: do not allocate tensors upfront! - auto compute_stream = calc_ctx->stream(); + auto compute_stream = make_cuda_stream(calc_ctx->stream(), device_id); if (allocate_on_comm_stream) { EP_HOST_ASSERT(previous_event.has_value() && async); - deep_ep::SetAllocatorStreamForGPUContext(comm_stream, calc_ctx); + deep_ep::SetAllocatorStreamForGPUContext(comm_stream.value().stream(), calc_ctx); } // Wait previous tasks to be finished if (previous_event.has_value()) { - stream_wait(comm_stream, previous_event.value()); + stream_wait(comm_stream.value(), previous_event.value()); } else { - stream_wait(comm_stream, compute_stream); + stream_wait(comm_stream.value(), compute_stream); } // Create handles (only return for non-cached mode) @@ -1116,7 +1119,7 @@ Buffer::internode_dispatch(const torch::Tensor& x, config.num_max_nvl_chunked_recv_tokens, barrier_signal_ptrs_gpu, rank, - comm_stream, + comm_stream.value().stream(), config.get_rdma_buffer_size_hint(hidden_int4 * sizeof(int4), num_ranks), num_nvl_bytes, true, @@ -1157,7 +1160,7 @@ Buffer::internode_dispatch(const torch::Tensor& x, config.num_max_nvl_chunked_recv_tokens, barrier_signal_ptrs_gpu, rank, - comm_stream, + comm_stream.value().stream(), config.get_rdma_buffer_size_hint(hidden_int4 * sizeof(int4), num_ranks), num_nvl_bytes, low_latency_mode); @@ -1264,14 +1267,14 @@ Buffer::internode_dispatch(const torch::Tensor& x, rank, num_ranks, cached_mode, - comm_stream, + comm_stream.value().stream(), num_channels, low_latency_mode); // Wait streams std::optional event; if (async) { - event = EventHandle(comm_stream); + event = EventHandle(comm_stream.value()); for (auto& t : {x, is_token_in_rank, recv_x, @@ -1279,7 +1282,7 @@ Buffer::internode_dispatch(const torch::Tensor& x, recv_rdma_rank_prefix_sum, gbl_channel_prefix_matrix, recv_gbl_rank_prefix_sum}) { - t.record_stream(comm_stream); + t.record_stream(comm_stream.value()); if (allocate_on_comm_stream) t.record_stream(compute_stream); } @@ -1301,17 +1304,17 @@ Buffer::internode_dispatch(const torch::Tensor& x, send_rdma_head, send_nvl_head, recv_src_meta}) { - to.has_value() ? to->record_stream(comm_stream) : void(); + to.has_value() ? to->record_stream(comm_stream.value()) : void(); if (allocate_on_comm_stream) to.has_value() ? to->record_stream(compute_stream) : void(); } } else { - stream_wait(compute_stream, comm_stream); + stream_wait(compute_stream, comm_stream.value()); } // Switch back compute stream if (allocate_on_comm_stream) { - deep_ep::SetAllocatorStreamForGPUContext(compute_stream, calc_ctx); + deep_ep::SetAllocatorStreamForGPUContext(compute_stream.stream(), calc_ctx); } // Return values @@ -1386,17 +1389,17 @@ std::tuple, std::optionalstream(); + auto compute_stream = make_cuda_stream(calc_ctx->stream(), device_id); if (allocate_on_comm_stream) { EP_HOST_ASSERT(previous_event.has_value() && async); - deep_ep::SetAllocatorStreamForGPUContext(comm_stream, calc_ctx); + deep_ep::SetAllocatorStreamForGPUContext(comm_stream.value().stream(), calc_ctx); } // Wait previous tasks to be finished if (previous_event.has_value()) { - stream_wait(comm_stream, previous_event.value()); + stream_wait(comm_stream.value(), previous_event.value()); } else { - stream_wait(comm_stream, compute_stream); + stream_wait(comm_stream.value(), compute_stream); } // Top-k checks @@ -1436,7 +1439,7 @@ std::tuple, std::optional, std::optional event; if (async) { - event = EventHandle(comm_stream); + event = EventHandle(comm_stream.value()); for (auto& t : {x, src_meta, is_combined_token_in_rank, @@ -1499,22 +1502,22 @@ std::tuple, std::optionalrecord_stream(comm_stream) : void(); + to.has_value() ? to->record_stream(comm_stream.value()) : void(); if (allocate_on_comm_stream) to.has_value() ? to->record_stream(compute_stream) : void(); } } else { - stream_wait(compute_stream, comm_stream); + stream_wait(compute_stream, comm_stream.value()); } // Switch back compute stream if (allocate_on_comm_stream) { - deep_ep::SetAllocatorStreamForGPUContext(compute_stream, calc_ctx); + deep_ep::SetAllocatorStreamForGPUContext(compute_stream.stream(), calc_ctx); } // Return values @@ -1609,7 +1612,7 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, // Wait previous tasks to be finished // NOTES: the hook mode will always use the default stream auto compute_stream = at::cuda::getCurrentCUDAStream(); - auto launch_stream = return_recv_hook ? compute_stream : comm_stream; + auto launch_stream = return_recv_hook ? compute_stream : comm_stream.value(); EP_HOST_ASSERT(not(async and return_recv_hook)); if (not return_recv_hook) stream_wait(launch_stream, compute_stream); @@ -1673,7 +1676,7 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, use_ue8m0, workspace, num_device_sms, - launch_stream, + launch_stream.stream(), phases); }; launcher(return_recv_hook ? LOW_LATENCY_SEND_PHASE : (LOW_LATENCY_SEND_PHASE | LOW_LATENCY_RECV_PHASE)); @@ -1754,7 +1757,7 @@ std::tuple, std::optional, std::optional #include +#include #include #include @@ -85,7 +86,7 @@ struct Buffer { shared_memory::MemHandle ipc_handles[NUM_MAX_NVL_PEERS]; // Stream for communication - cudaStream_t comm_stream; + std::optional comm_stream; phi::distributed::NCCLCommContext* comm_ctx; phi::GPUContext* calc_ctx; @@ -150,7 +151,14 @@ struct Buffer { torch::Tensor get_local_buffer_tensor(const pybind11::object& dtype, int64_t offset, bool use_rdma_buffer) const; - cudaStream_t get_comm_stream() const; + at::cuda::CUDAStream get_comm_stream() const { + return comm_stream.value(); + } + + // Helper to get raw stream for CUDA APIs + cudaStream_t get_comm_stream_raw() const { + return comm_stream.value().stream(); + } void sync(const std::vector& device_ids, const std::vector>& all_gathered_handles, @@ -314,4 +322,15 @@ inline void SetAllocatorStreamForGPUContext(gpuStream_t stream, .get()); } +// Helper to create CUDAStream from raw cudaStream_t +inline at::cuda::CUDAStream make_cuda_stream(cudaStream_t raw_stream, int device_id = -1) { + if (device_id == -1) { + CUDA_CHECK(cudaGetDevice(&device_id)); + } + c10::StreamId sid = static_cast(reinterpret_cast(raw_stream)); + return at::cuda::CUDAStream(c10::Stream(c10::Stream::UNSAFE, + c10::Device(c10::DeviceType::CUDA, device_id), + sid)); +} + } // namespace deep_ep diff --git a/csrc/event.hpp b/csrc/event.hpp index e2bc397a..449e9e49 100644 --- a/csrc/event.hpp +++ b/csrc/event.hpp @@ -15,7 +15,7 @@ struct EventHandle { event->record(at::cuda::getCurrentCUDAStream()); } - explicit EventHandle(const cudaStream_t& stream) { + explicit EventHandle(const at::cuda::CUDAStream& stream) { event = std::make_shared(torch::kCUDA); event->record(stream); } @@ -24,25 +24,25 @@ struct EventHandle { void current_stream_wait() const { CUDA_CHECK(cudaStreamWaitEvent( - at::cuda::getCurrentCUDAStream().raw_stream(), + at::cuda::getCurrentCUDAStream().stream(), event->cuda_event(), 0)); } }; -torch::Event create_event(const cudaStream_t& s) { +torch::Event create_event(const at::cuda::CUDAStream& s) { auto event = torch::Event(torch::kCUDA); event.record(s); return event; } -inline void stream_wait(const cudaStream_t& s_0, const cudaStream_t& s_1) { +inline void stream_wait(const at::cuda::CUDAStream& s_0, const at::cuda::CUDAStream& s_1) { EP_HOST_ASSERT(s_0 != s_1); - CUDA_CHECK(cudaStreamWaitEvent(s_0, create_event(s_1).cuda_event(), 0)); + CUDA_CHECK(cudaStreamWaitEvent(s_0.stream(), create_event(s_1).cuda_event(), 0)); } -inline void stream_wait(const cudaStream_t& s, const EventHandle& event) { - CUDA_CHECK(cudaStreamWaitEvent(s, event.event->cuda_event(), 0)); +inline void stream_wait(const at::cuda::CUDAStream& s, const EventHandle& event) { + CUDA_CHECK(cudaStreamWaitEvent(s.stream(), event.event->cuda_event(), 0)); } } // namespace deep_ep diff --git a/setup.py b/setup.py index 47ec7c04..4569ff08 100644 --- a/setup.py +++ b/setup.py @@ -54,7 +54,7 @@ def get_nvshmem_host_lib_name(base_dir): include_dirs.extend([f'{nvshmem_dir}/include']) library_dirs.extend([f'{nvshmem_dir}/lib']) nvcc_dlink.extend(['-dlink', f'-L{nvshmem_dir}/lib', '-lnvshmem_device']) - extra_link_args.extend([f'-l:{nvshmem_host_lib}', '-l:libnvshmem_device.a', f'-Wl,-rpath,{nvshmem_dir}/lib']) + extra_link_args.extend([f'-l:{nvshmem_host_lib}', f'-Wl,-rpath,{nvshmem_dir}/lib']) if int(os.getenv('DISABLE_SM90_FEATURES', 0)): # Prefer A100