From 7c8151eca89a0e59ef608e0f0e390b31376249fa Mon Sep 17 00:00:00 2001 From: M Shehtab Zaman Date: Wed, 10 Sep 2025 14:04:29 -0700 Subject: [PATCH 1/3] Add NVSHMEM scatter cached implementation with option for static cache to reuse cache over multiple layers --- DGraph/distributed/RankLocalOps.py | 48 +++++++++- .../distributed/csrc/local_data_kernels.cuh | 52 +++++++++++ .../distributed/csrc/torch_local_bindings.cpp | 1 + .../distributed/csrc/torch_local_kernels.cu | 54 ++++++++++++ DGraph/distributed/csrc/torch_nvshmem_p2p.cu | 65 ++++++++++++++ .../csrc/torch_nvshmem_p2p_bindings.cpp | 1 + DGraph/distributed/include/torch_local.hpp | 15 +++- .../nvshmem/NVSHMEMBackendEngine.py | 49 ++++++++++- DGraph/distributed/nvshmem/_nvshmem_cache.py | 87 +++++++++++++++++++ 9 files changed, 368 insertions(+), 4 deletions(-) create mode 100644 DGraph/distributed/nvshmem/_nvshmem_cache.py diff --git a/DGraph/distributed/RankLocalOps.py b/DGraph/distributed/RankLocalOps.py index c4b6de0..4fc505a 100644 --- a/DGraph/distributed/RankLocalOps.py +++ b/DGraph/distributed/RankLocalOps.py @@ -18,7 +18,11 @@ import torch try: - from DGraph.torch_local import local_masked_gather, local_masked_scatter + from DGraph.torch_local import ( + local_masked_gather, + local_masked_scatter, + local_multi_output_scatter, + ) _LOCAL_OPT_KERNELS_AVAILABLE = True except ImportError: @@ -140,7 +144,9 @@ def RankLocalRenumberingWithMapping(_indices, rank_mapping): unique_indices, inverse_indices = torch.unique(_indices, return_inverse=True) rank_mapping = rank_mapping.to(_indices.device) renumbered_indices = inverse_indices - unique_rank_mapping = torch.zeros_like(unique_indices, dtype=rank_mapping.dtype, device=rank_mapping.device) + unique_rank_mapping = torch.zeros_like( + unique_indices, dtype=rank_mapping.dtype, device=rank_mapping.device + ) unique_rank_mapping.scatter_(0, inverse_indices, rank_mapping) return renumbered_indices, unique_indices, unique_rank_mapping @@ -198,3 +204,41 @@ def LocalAggregateWithRemapping( local_aggregated_data.scatter_add_(1, renumbered_indices, global_data) return local_aggregated_data, new_mapping + + +def RankLocalMultiOutputScatter( + _src: torch.Tensor, + _output: torch.Tensor, + _workspace: torch.Tensor, + local_indices_slice: torch.Tensor, + rank_mapping: torch.Tensor, + cur_rank_offset: int, + rank: int, +) -> tuple[torch.Tensor, torch.Tensor]: + """ + This function scatters the data from the source rank to the destination rank. + """ + if not _LOCAL_OPT_KERNELS_AVAILABLE: + raise ImportError( + "Optimized local kernels are not available. Please compile the local kernels." 
+ ) + bs = _src.shape[0] + num_features = _src.shape[-1] + num_local_output_rows = _output.shape[1] + num_workspace_rows = _workspace.shape[1] + num_indices = local_indices_slice.shape[0] + local_multi_output_scatter( + _src, + _output, + _workspace, + local_indices_slice.cuda(), + rank_mapping.cuda(), + bs, + num_features, + num_local_output_rows, + num_workspace_rows, + num_indices, + cur_rank_offset, + rank, + ) + return _output, _workspace diff --git a/DGraph/distributed/csrc/local_data_kernels.cuh b/DGraph/distributed/csrc/local_data_kernels.cuh index f12ca4a..85ca995 100644 --- a/DGraph/distributed/csrc/local_data_kernels.cuh +++ b/DGraph/distributed/csrc/local_data_kernels.cuh @@ -157,6 +157,58 @@ namespace Local } } + __global__ void Multi_Output_Scatter_Kernel( + const float *__restrict__ values, + const long *__restrict__ indices, + const long *__restrict__ rank_placement, + float *__restrict__ output, + float *__restrict__ workspace, + const int mini_batch_size, + const int num_values_rows, + const int num_cols, + const int num_output_rows, + const int num_workspace_rows, + const long cur_rank_index_offset, + const int current_rank) + { + const size_t gidx = threadIdx.x + blockIdx.x * blockDim.x; + const size_t gidy = threadIdx.y + blockIdx.y * blockDim.y; + const size_t gidz = threadIdx.z + blockIdx.z * blockDim.z; + + const size_t nthreadsx = gridDim.x * blockDim.x; + const size_t nthreadsy = gridDim.y * blockDim.y; + const size_t nthreadsz = gridDim.z * blockDim.z; + + for (size_t mb_i = gidz; mb_i < mini_batch_size; mb_i += nthreadsz) + { + const auto values_offset = mb_i * num_cols * num_values_rows; + const auto output_offset = mb_i * num_cols * num_output_rows; + const auto workspace_offset = mb_i * num_cols * num_workspace_rows; + const auto ind_offset = mb_i * num_values_rows; + const auto rank_placement_offset = mb_i * num_values_rows; + + for (size_t row = gidy; row < num_values_rows; row += nthreadsy) + { + const int ind = indices[ind_offset + row]; + const int row_rank = rank_placement[rank_placement_offset + row]; + + // Determine whether to use output or workspace + auto buffer_ptr = row_rank == current_rank ? output : workspace; + const auto buffer_offset = row_rank == current_rank ? output_offset : workspace_offset; + const auto adjusted_ind = row_rank == current_rank ? ind - cur_rank_index_offset : ind; + + for (size_t i = gidx; i < num_cols; i += nthreadsx) + { + if (adjusted_ind > -1 && adjusted_ind < (row_rank == current_rank ? 
num_output_rows : num_workspace_rows)) + { + const auto val = values[values_offset + row * num_cols + i]; + atomicAdd(&buffer_ptr[buffer_offset + adjusted_ind * num_cols + i], Max(val, 0.0)); + } + } + } + } + } + __global__ void Rank_Local_Gather_Kernel( const float *__restrict__ values, const long *__restrict__ indices, diff --git a/DGraph/distributed/csrc/torch_local_bindings.cpp b/DGraph/distributed/csrc/torch_local_bindings.cpp index a91f516..0147a6d 100644 --- a/DGraph/distributed/csrc/torch_local_bindings.cpp +++ b/DGraph/distributed/csrc/torch_local_bindings.cpp @@ -21,4 +21,5 @@ PYBIND11_MODULE(torch_local, m) { m.def("local_masked_gather", &local_masked_gather, "Masked Gather"); m.def("local_masked_scatter", &local_masked_scatter, "Masked Scatter"); + m.def("local_multi_output_scatter", &local_multi_output_scatter, "Multi-output Scatter"); } diff --git a/DGraph/distributed/csrc/torch_local_kernels.cu b/DGraph/distributed/csrc/torch_local_kernels.cu index b70bf36..356dbe4 100644 --- a/DGraph/distributed/csrc/torch_local_kernels.cu +++ b/DGraph/distributed/csrc/torch_local_kernels.cu @@ -114,4 +114,58 @@ torch::Tensor local_masked_scatter(torch::Tensor input, rank); CUDACHECK(cudaGetLastError()); return output; +} + +torch::Tensor local_multi_output_scatter(torch::Tensor input, + torch::Tensor indices, + torch::Tensor rank_local_placement, + torch::Tensor output, + torch::Tensor workspace, + const int num_batches, + const int num_values_rows, + const int num_cols, + const int num_output_rows, + const int num_workspace_rows, + const long cur_rank_index_offset, + const int rank) +{ + CHECK_INPUT(input); + CHECK_INPUT(indices); + CHECK_INPUT(rank_local_placement); + CHECK_INPUT(output); + CHECK_INPUT(workspace); + + const float *input_ptr = input.data_ptr(); + const long *indices_ptr = indices.data_ptr(); + const long *rank_local_placement_ptr = rank_local_placement.data_ptr(); + float *output_ptr = output.data_ptr(); + float *workspace_ptr = workspace.data_ptr(); + + dim3 block_dims, grid_dims; + block_dims.x = 32; + block_dims.y = 32; + block_dims.z = 1; + + const auto num_grids_needed = (num_output_rows + block_dims.y - 1) / block_dims.y; + const auto num_col_grids_needed = (num_cols + block_dims.x - 1) / block_dims.x; + grid_dims.x = num_col_grids_needed < 65535 ? num_col_grids_needed : 65535; + grid_dims.y = num_grids_needed < 65535 ? 
num_grids_needed : 65535;
+    grid_dims.z = 1;
+    // Get the default stream for the current device
+    at::cuda::CUDAStream defaultStream = at::cuda::getDefaultCUDAStream(input.device().index());
+    Local::Multi_Output_Scatter_Kernel<<<grid_dims, block_dims, 0, defaultStream>>>(
+        input_ptr,
+        indices_ptr,
+        rank_local_placement_ptr,
+        output_ptr,
+        workspace_ptr,
+        num_batches,
+        num_values_rows,
+        num_cols,
+        num_output_rows,
+        num_workspace_rows,
+        cur_rank_index_offset,
+        rank);
+    CUDACHECK(cudaGetLastError());
+    return output;
+}
\ No newline at end of file
diff --git a/DGraph/distributed/csrc/torch_nvshmem_p2p.cu b/DGraph/distributed/csrc/torch_nvshmem_p2p.cu
index 1c88982..576e4b0 100644
--- a/DGraph/distributed/csrc/torch_nvshmem_p2p.cu
+++ b/DGraph/distributed/csrc/torch_nvshmem_p2p.cu
@@ -163,6 +163,71 @@ void NVSHMEMP2P::dist_put(torch::Tensor input,
     CUDACHECK(cudaStreamSynchronize(defaultStream));
 }
 
+void NVSHMEMP2P::dist_put_precomputed(torch::Tensor input,
+                                      torch::Tensor output,
+                                      torch::Tensor workspace,
+                                      torch::Tensor indices,
+                                      torch::Tensor rank_mappings,
+                                      torch::Tensor dst_ranks,
+                                      torch::Tensor dst_offsets,
+                                      const int num_input_rows,
+                                      const int num_cols,
+                                      const int num_output_rows)
+{
+    CHECK_INPUT(input);
+    CHECK_INPUT(output);
+    CHECK_INPUT(workspace);
+    CHECK_INPUT(indices);
+    CHECK_INPUT(rank_mappings);
+    CHECK_INPUT(dst_ranks);
+    CHECK_INPUT(dst_offsets);
+
+    TORCH_CHECK(input.is_contiguous());
+    TORCH_CHECK(output.is_contiguous());
+    TORCH_CHECK(workspace.is_contiguous());
+    TORCH_CHECK(indices.is_contiguous());
+    TORCH_CHECK(rank_mappings.is_contiguous());
+    TORCH_CHECK(dst_ranks.is_contiguous());
+    TORCH_CHECK(dst_offsets.is_contiguous());
+
+    TORCH_INTERNAL_ASSERT(input.device().type() == at::DeviceType::CUDA);
+    TORCH_INTERNAL_ASSERT(output.device().type() == at::DeviceType::CUDA);
+    TORCH_INTERNAL_ASSERT(workspace.device().type() == at::DeviceType::CUDA);
+    TORCH_INTERNAL_ASSERT(indices.device().type() == at::DeviceType::CUDA);
+    TORCH_INTERNAL_ASSERT(rank_mappings.device().type() == at::DeviceType::CUDA);
+    TORCH_INTERNAL_ASSERT(dst_ranks.device().type() == at::DeviceType::CUDA);
+    TORCH_INTERNAL_ASSERT(dst_offsets.device().type() == at::DeviceType::CUDA);
+
+    if (!m_initialized)
+    {
+        throw std::runtime_error("NVSHMEMP2P is not initialized");
+    }
+
+    // Get the pointers to the data
+    const float *input_ptr = input.data_ptr();
+    const float *workspace_ptr = workspace.data_ptr();
+    const long *indices_ptr = indices.data_ptr();
+    const long *rank_mappings_ptr = rank_mappings.data_ptr();
+    const long *dst_ranks_ptr = dst_ranks.data_ptr();
+    const long *dst_offsets_ptr = dst_offsets.data_ptr();
+    float *output_ptr = output.data_ptr();
+
+    const auto current_rank = NVSHMEMP2P::m_rank;
+    dim3 block_dims, grid_dims;
+    block_dims.x = 32;
+    block_dims.y = 16;
+    block_dims.z = 1;
+
+    const auto num_grids_needed = (num_input_rows + block_dims.y - 1) / block_dims.y;
+    grid_dims.y = num_grids_needed < 65535 ?
num_grids_needed : 65535; + grid_dims.x = (num_cols + block_dims.x - 1) / block_dims.x; + at::cuda::CUDAStream defaultStream = at::cuda::getDefaultCUDAStream(input.device().index()); + + nvshmemx_quiet_on_stream(defaultStream); + CUDACHECK(cudaStreamSynchronize(defaultStream)); + // Launch the kernel + +} void NVSHMEMP2P::dist_get(torch::Tensor input, torch::Tensor output, torch::Tensor indices, diff --git a/DGraph/distributed/csrc/torch_nvshmem_p2p_bindings.cpp b/DGraph/distributed/csrc/torch_nvshmem_p2p_bindings.cpp index 105d35d..5fc0e1a 100644 --- a/DGraph/distributed/csrc/torch_nvshmem_p2p_bindings.cpp +++ b/DGraph/distributed/csrc/torch_nvshmem_p2p_bindings.cpp @@ -23,6 +23,7 @@ PYBIND11_MODULE(torch_nvshmem_p2p, m) .def("init", &NVSHMEMP2P::init) .def("finalize", &NVSHMEMP2P::finalize) .def("dist_put", &NVSHMEMP2P::dist_put) + .def("dist_put_precomputed", &NVSHMEMP2P::dist_put_precomputed) .def("allocate_symmetric_memory", &NVSHMEMP2P::AllocateSymmetricMemory) .def("clone_tensor", &NVSHMEMP2P::clone_tensor) .def("padded_clone_tensor", &NVSHMEMP2P::padded_clone_tensor) diff --git a/DGraph/distributed/include/torch_local.hpp b/DGraph/distributed/include/torch_local.hpp index f780160..45f08ee 100644 --- a/DGraph/distributed/include/torch_local.hpp +++ b/DGraph/distributed/include/torch_local.hpp @@ -19,4 +19,17 @@ torch::Tensor local_masked_scatter(torch::Tensor input, const int num_values_rows, const int num_cols, const int num_output_rows, - const int rank); \ No newline at end of file + const int rank); + +torch::Tensor local_multi_output_scatter(torch::Tensor input, + torch::Tensor indices, + torch::Tensor rank_local_placement, + torch::Tensor output, + torch::Tensor workspace, + const int num_batches, + const int num_values_rows, + const int num_cols, + const int num_output_rows, + const int num_workspace_rows, + const long cur_rank_index_offset, + const int rank); \ No newline at end of file diff --git a/DGraph/distributed/nvshmem/NVSHMEMBackendEngine.py b/DGraph/distributed/nvshmem/NVSHMEMBackendEngine.py index 95ed20f..d4518ef 100644 --- a/DGraph/distributed/nvshmem/NVSHMEMBackendEngine.py +++ b/DGraph/distributed/nvshmem/NVSHMEMBackendEngine.py @@ -17,9 +17,51 @@ import DGraph.torch_nvshmem_p2p as nvshmem import warnings from torch.autograd import Function +from DGraph.distributed.nvshmem._nvshmem_cache import NVSHMEMScatterCache +from DGraph.distributed.RankLocalOps import RankLocalMultiOutputScatter -def _nvshmmem_gather(send_tensor, indices, rank_mappings): +def _nvshmem_scatter_cache(send_tensor, cache: NVSHMEMScatterCache, workspace): + bs = send_tensor.shape[0] + num_features = send_tensor.shape[2] + + num_elem = bs * num_features * cache.num_output_rows + scattered_tensor = nvshmem.NVSHMEMP2P.allocate_symmetric_memory( + num_elem, send_tensor.device.index + ) + scattered_tensor.fill_(0).float() + scattered_tensor = scattered_tensor.reshape( + (bs, cache.num_output_rows, num_features) + ) + cur_rank = nvshmem.NVSHMEMP2P.get_rank() + cur_rank_offset = int(cache.index_offsets_per_rank[cur_rank].item()) + scattered_tensor, workspace = RankLocalMultiOutputScatter( + send_tensor, + scattered_tensor, + workspace, + cache.local_indices_slice, + cache.local_dest_ranks, + cur_rank_offset, + cur_rank, + ) + # Now the scattered tesnor contains the local contributions + # The workspace contains the data to be communicated to other ranks + # but already locally accumulated + + nvshmem.NVSHMEMP2P.dist_put( + workspace, + scattered_tensor, + cache.comm_indices, + 
cache.local_dest_ranks, + bs, + cache.min_workspace_size, + num_features, + cache.num_output_rows, + ) + return scattered_tensor + + +def _nvshmem_gather(send_tensor, indices, rank_mappings): bs = send_tensor.shape[0] num_input_rows = send_tensor.shape[1] @@ -169,6 +211,7 @@ class NVSHMEMBackendEngine(BackendEngine): _partition_size = -1 _local_rank = -1 _partition_num = -1 + _STATIC_CACHE = None def __init__(self, ranks_per_graph=-1, *args, **kwargs): # check if already initialized @@ -221,6 +264,10 @@ def get_partition_num() -> int: def get_partition_size() -> int: return NVSHMEMBackendEngine._partition_size + def add_cache(self, cache: NVSHMEMScatterCache) -> None: + NVSHMEMBackendEngine._STATIC_CACHE = cache + return + def gather(self, input_tensor, indices, rank_mappings): assert ( len(input_tensor.shape) == 3 diff --git a/DGraph/distributed/nvshmem/_nvshmem_cache.py b/DGraph/distributed/nvshmem/_nvshmem_cache.py new file mode 100644 index 0000000..2f63c51 --- /dev/null +++ b/DGraph/distributed/nvshmem/_nvshmem_cache.py @@ -0,0 +1,87 @@ +from dataclasses import dataclass +import torch +import torch.distributed as dist + + +@dataclass +class NVSHMEMScatterCache: + """This class caches the local scatter operators and index remappings + for each rank to perform local accumulations on a workspace buffer. + """ + + num_output_rows: int + local_indices_slice: torch.Tensor + local_dest_ranks: torch.Tensor + comm_indices: torch.Tensor + rank_mapping: torch.Tensor + index_offsets_per_rank: torch.Tensor + min_workspace_size: int + rank: int + world_size: int + + +def NVSHMEMScatterCacheGenerator( + num_output_rows: int, + local_indices_slice: torch.Tensor, + local_rank_mapping: torch.Tensor, + vertices_per_rank: torch.Tensor, + rank: int, + world_size: int, +): + """ + This function generates the NVSHMEM scatter cache for each rank. It does the + following: + + 1. Computes the index offsets per rank using the vertices per rank. This is used + by the NVSHMEM scatter operation to place the data in the correct location + in the symmetric workspace buffer which is indexed by local rank. + 2. Computes the minimum workspace size needed for the NVSHMEM scatter operation. + 3. Remaps the indices that need to be communicated to other ranks to an internal + workspace index. This allows each rank to perform local accumulations of the + communicated data before doing the communication step. 
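+
+    For example (illustrative numbers): with world_size=2, rank=0, and
+    vertices_per_rank=[4, 4], index_offsets_per_rank is [0, 4], so rank 0 owns
+    global output rows 0-3 and rank 1 owns rows 4-7, and a value destined for
+    global row 5 on rank 1 lands at local offset 5 - 4 = 1 of that rank's
+    symmetric buffer.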
+ """ + # Use the vertices per rank to compute the index offsets + index_offsets_per_rank = torch.zeros(world_size, dtype=torch.long) + index_offsets_per_rank[1:] = torch.cumsum(vertices_per_rank, dim=0)[:-1] + + # Find the number of unique messages from the local rank mapping + comm_mask = local_rank_mapping != rank + local_dest_ranks = local_rank_mapping[comm_mask] + local_indices_slice = local_indices_slice[comm_mask] + unique_dest_ranks, dest_indices = torch.unique( + local_dest_ranks, return_inverse=True + ) + + # The amount of communication saved with this operation is equal to the + # len(unique_dest_ranks) - len(local_dest_ranks) + + local_unique_messages = unique_dest_ranks.numel() + num_unique_messages = torch.tensor([unique_dest_ranks.numel()]) + + # This is the size of the NVSHMEM workspace buffer needed + # We need to find the maximum number of unique messages across all ranks + # because all ranks must allocate the same sized buffer + dist.all_reduce(num_unique_messages, op=dist.ReduceOp.MAX) + + global_min_workspace_size = int(num_unique_messages[0].item()) + + workspace_mapping = torch.zeros(global_min_workspace_size, dtype=torch.long) - 1 + workspace_mapping[0:local_unique_messages] = unique_dest_ranks + updated_local_indices = local_indices_slice.clone() + remap_comm_to_workspace = torch.zeros_like(local_dest_ranks) + remap_comm_to_workspace.scatter_( + 0, dest_indices, torch.arange(local_unique_messages, device=dest_indices.device) + ) + updated_local_indices[comm_mask] = remap_comm_to_workspace + + return NVSHMEMScatterCache( + num_output_rows=num_output_rows, + local_indices_slice=updated_local_indices, + local_dest_ranks=local_dest_ranks, + rank_mapping=local_rank_mapping, + comm_indices=workspace_mapping, + index_offsets_per_rank=index_offsets_per_rank, + min_workspace_size=global_min_workspace_size, + rank=rank, + world_size=world_size, + ) From 71b8bb83d3b2006390e4676ceaeef958acc39d01 Mon Sep 17 00:00:00 2001 From: M Shehtab Zaman Date: Wed, 10 Sep 2025 14:14:34 -0700 Subject: [PATCH 2/3] Remove combined p2p kernel for now (will look into merging in the future) Fix some typos + style changes --- .../distributed/csrc/torch_local_kernels.cu | 2 +- DGraph/distributed/csrc/torch_nvshmem_p2p.cu | 64 ------------------- .../csrc/torch_nvshmem_p2p_bindings.cpp | 1 - DGraph/distributed/include/torch_local.hpp | 2 +- .../nvshmem/NVSHMEMBackendEngine.py | 4 +- 5 files changed, 4 insertions(+), 69 deletions(-) diff --git a/DGraph/distributed/csrc/torch_local_kernels.cu b/DGraph/distributed/csrc/torch_local_kernels.cu index 356dbe4..20db91c 100644 --- a/DGraph/distributed/csrc/torch_local_kernels.cu +++ b/DGraph/distributed/csrc/torch_local_kernels.cu @@ -168,4 +168,4 @@ torch::Tensor local_multi_output_scatter(torch::Tensor input, rank); CUDACHECK(cudaGetLastError()); return output; -} \ No newline at end of file +} diff --git a/DGraph/distributed/csrc/torch_nvshmem_p2p.cu b/DGraph/distributed/csrc/torch_nvshmem_p2p.cu index 576e4b0..8666a3b 100644 --- a/DGraph/distributed/csrc/torch_nvshmem_p2p.cu +++ b/DGraph/distributed/csrc/torch_nvshmem_p2p.cu @@ -163,71 +163,7 @@ void NVSHMEMP2P::dist_put(torch::Tensor input, CUDACHECK(cudaStreamSynchronize(defaultStream)); } -void NVSHMEMP2P::dist_put_precomputed(torch::Tensor input, - torch::Tensor output, - torch::Tensor workspace, - torch::Tensor indices, - torch::Tensor rank_mappings, - torch::Tensor dst_ranks, - torch::Tensor dst_offsets, - const int num_input_rows, - const int num_cols, - const int num_output_rows) -{ - 
CHECK_INPUT(input); - CHECK_INPUT(output); - CHECK_INPUT(workspace); - CHECK_INPUT(indices); - CHECK_INPUT(rank_mappings); - CHECK_INPUT(dst_ranks); - CHECK_INPUT(dst_offsets); - - TORCH_CHECK(input.is_contiguous()); - TORCH_CHECK(output.is_contiguous()); - TORCH_CHECK(workspace.is_contiguous()); - TORCH_CHECK(indices.is_contiguous()); - TORCH_CHECK(rank_mappings.is_contiguous()); - TORCH_CHECK(dst_ranks.is_contiguous()); - TORCH_CHECK(dst_offsets.is_contiguous()); - - TORCH_INTERNAL_ASSERT(input.device().type() == at::DeviceType::CUDA); - TORCH_INTERNAL_ASSERT(output.device().type() == at::DeviceType::CUDA); - TORCH_INTERNAL_ASSERT(workspace.device().type() == at::DeviceType::CUDA); - TORCH_INTERNAL_ASSERT(indices.device().type() == at::DeviceType::CUDA); - TORCH_INTERNAL_ASSERT(rank_mappings.device().type() == at::DeviceType::CUDA); - TORCH_INTERNAL_ASSERT(dst_ranks.device().type() == at::DeviceType::CUDA); - TORCH_INTERNAL_ASSERT(dst_offsets.device().type() == at::DeviceType::CUDA); - - if (!m_initialized) - { - throw std::runtime_error("NVSHMEMP2P is not initialized"); - } - // Get the pointers to the data - const float *input_ptr = input.data_ptr(); - const float *workspace_ptr = workspace.data_ptr(); - const long *indices_ptr = indices.data_ptr(); - const long *rank_mappings_ptr = rank_mappings.data_ptr(); - const long *dst_ranks_ptr = dst_ranks.data_ptr(); - const long *dst_offsets_ptr = dst_offsets.data_ptr(); - float *output_ptr = output.data_ptr(); - - const auto current_rank = NVSHMEMP2P::m_rank; - dim3 block_dims, grid_dims; - block_dims.x = 32; - block_dims.y = 16; - block_dims.z = 1; - - const auto num_grids_needed = (num_input_rows + block_dims.y - 1) / block_dims.y; - grid_dims.y = num_grids_needed < 65535 ? num_grids_needed : 65535; - grid_dims.x = (num_cols + block_dims.x - 1) / block_dims.x; - at::cuda::CUDAStream defaultStream = at::cuda::getDefaultCUDAStream(input.device().index()); - - nvshmemx_quiet_on_stream(defaultStream); - CUDACHECK(cudaStreamSynchronize(defaultStream)); - // Launch the kernel - -} void NVSHMEMP2P::dist_get(torch::Tensor input, torch::Tensor output, torch::Tensor indices, diff --git a/DGraph/distributed/csrc/torch_nvshmem_p2p_bindings.cpp b/DGraph/distributed/csrc/torch_nvshmem_p2p_bindings.cpp index 5fc0e1a..105d35d 100644 --- a/DGraph/distributed/csrc/torch_nvshmem_p2p_bindings.cpp +++ b/DGraph/distributed/csrc/torch_nvshmem_p2p_bindings.cpp @@ -23,7 +23,6 @@ PYBIND11_MODULE(torch_nvshmem_p2p, m) .def("init", &NVSHMEMP2P::init) .def("finalize", &NVSHMEMP2P::finalize) .def("dist_put", &NVSHMEMP2P::dist_put) - .def("dist_put_precomputed", &NVSHMEMP2P::dist_put_precomputed) .def("allocate_symmetric_memory", &NVSHMEMP2P::AllocateSymmetricMemory) .def("clone_tensor", &NVSHMEMP2P::clone_tensor) .def("padded_clone_tensor", &NVSHMEMP2P::padded_clone_tensor) diff --git a/DGraph/distributed/include/torch_local.hpp b/DGraph/distributed/include/torch_local.hpp index 45f08ee..3fb8e26 100644 --- a/DGraph/distributed/include/torch_local.hpp +++ b/DGraph/distributed/include/torch_local.hpp @@ -32,4 +32,4 @@ torch::Tensor local_multi_output_scatter(torch::Tensor input, const int num_output_rows, const int num_workspace_rows, const long cur_rank_index_offset, - const int rank); \ No newline at end of file + const int rank); diff --git a/DGraph/distributed/nvshmem/NVSHMEMBackendEngine.py b/DGraph/distributed/nvshmem/NVSHMEMBackendEngine.py index d4518ef..9a064f6 100644 --- a/DGraph/distributed/nvshmem/NVSHMEMBackendEngine.py +++ 
b/DGraph/distributed/nvshmem/NVSHMEMBackendEngine.py
@@ -157,7 +157,7 @@ def forward(ctx, send_tensor, indices, rank_mappings):
         ctx.save_for_backward(indices, rank_mappings)
         num_rows = indices.shape[1]
         ctx.num_rows = num_rows
-        gathered_tensors = _nvshmmem_gather(send_tensor, indices, rank_mappings)
+        gathered_tensors = _nvshmem_gather(send_tensor, indices, rank_mappings)
         return gathered_tensors
 
     @staticmethod
@@ -194,7 +194,7 @@ def backward(ctx, grad_output):
         # nvshmem_grad_output = nvshmem.register_memory(grad_output)
 
         indices, rank_mappings = ctx.saved_tensors
-        input_grad = _nvshmmem_gather(grad_output, indices, rank_mappings)
+        input_grad = _nvshmem_gather(grad_output, indices, rank_mappings)
         nvshmem.deregister_memory(grad_output)
         indices_grad = None
         rank_mappings_grad = None

From 42a184a19d036581ca204db824948008e29938d7 Mon Sep 17 00:00:00 2001
From: M Shehtab Zaman
Date: Wed, 10 Sep 2025 14:19:33 -0700
Subject: [PATCH 3/3] Add fallback option for the local aggregation

---
 DGraph/distributed/RankLocalOps.py | 48 ++++++++++++++++++++++++++++--
 1 file changed, 46 insertions(+), 2 deletions(-)

diff --git a/DGraph/distributed/RankLocalOps.py b/DGraph/distributed/RankLocalOps.py
index 4fc505a..6c21605 100644
--- a/DGraph/distributed/RankLocalOps.py
+++ b/DGraph/distributed/RankLocalOps.py
@@ -206,6 +206,41 @@ def LocalAggregateWithRemapping(
     return local_aggregated_data, new_mapping
 
 
+def _unoptimized_RankLocalMultiOutputScatter(
+    _src: torch.Tensor,
+    _output: torch.Tensor,
+    _workspace: torch.Tensor,
+    local_indices_slice: torch.Tensor,
+    rank_mapping: torch.Tensor,
+    cur_rank_offset: int,
+    rank: int,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    """
+    Torch fallback: scatter-add local rows into _output, remote rows into _workspace.
+    """
+    num_features = _src.shape[-1]
+    bs = _src.shape[0]
+
+    local_mask = rank_mapping == rank
+    comm_mask = ~local_mask
+    local_output_indices_slice = local_indices_slice[local_mask] - cur_rank_offset
+    local_workspace_indices_slice = local_indices_slice[comm_mask]
+
+    _output.scatter_add_(
+        1,
+        local_output_indices_slice.view(1, -1, 1).expand(bs, -1, num_features),
+        _src[:, local_mask, :],
+    )
+
+    _workspace.scatter_add_(
+        1,
+        local_workspace_indices_slice.view(1, -1, 1).expand(bs, -1, num_features),
+        _src[:, comm_mask, :],
+    )
+
+    return _output, _workspace
+
+
 def RankLocalMultiOutputScatter(
     _src: torch.Tensor,
     _output: torch.Tensor,
@@ -219,8 +254,17 @@ def RankLocalMultiOutputScatter(
     """
     This function scatters the data from the source rank to the destination rank.
     """
     if not _LOCAL_OPT_KERNELS_AVAILABLE:
-        raise ImportError(
-            "Optimized local kernels are not available. Please compile the local kernels."
+        warnings.warn(
+            "Optimized local kernels are not available. Falling back to the default implementation."
+        )
+        return _unoptimized_RankLocalMultiOutputScatter(
+            _src,
+            _output,
+            _workspace,
+            local_indices_slice,
+            rank_mapping,
+            cur_rank_offset,
+            rank,
+        )
     bs = _src.shape[0]
     num_features = _src.shape[-1]
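
Illustrative note (not part of the patches): the fallback above partitions each incoming
row by its destination rank. Rows owned by the current rank are scatter-added into the
output buffer at offset-adjusted indices, while rows bound for other ranks are
scatter-added into the workspace slots that the later NVSHMEM put pushes out. A minimal
self-contained sketch of that partitioning; the shapes, index values, and rank/offset
choices below are made up for the example:

    import torch

    bs, num_features = 1, 4
    rank, cur_rank_offset = 0, 0                     # rank 0 owns global output rows 0-2
    src = torch.randn(bs, 6, num_features)           # six incoming rows
    indices = torch.tensor([0, 2, 1, 0, 1, 0])       # output row for local rows, workspace slot for remote rows
    rank_mapping = torch.tensor([0, 0, 0, 1, 1, 1])  # destination rank of each row

    output = torch.zeros(bs, 3, num_features)        # rows owned by this rank
    workspace = torch.zeros(bs, 3, num_features)     # locally accumulated rows bound for other ranks

    local_mask = rank_mapping == rank
    # Local rows: accumulate into the output at offset-adjusted indices.
    local_idx = (indices[local_mask] - cur_rank_offset).view(1, -1, 1).expand(bs, -1, num_features)
    output.scatter_add_(1, local_idx, src[:, local_mask, :])
    # Remote rows: accumulate into the workspace before communication.
    remote_idx = indices[~local_mask].view(1, -1, 1).expand(bs, -1, num_features)
    workspace.scatter_add_(1, remote_idx, src[:, ~local_mask, :])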