92 changes: 90 additions & 2 deletions DGraph/distributed/RankLocalOps.py
@@ -18,7 +18,11 @@
import torch

try:
from DGraph.torch_local import local_masked_gather, local_masked_scatter
from DGraph.torch_local import (
local_masked_gather,
local_masked_scatter,
local_multi_output_scatter,
)

_LOCAL_OPT_KERNELS_AVAILABLE = True
except ImportError:
@@ -140,7 +144,9 @@ def RankLocalRenumberingWithMapping(_indices, rank_mapping):
unique_indices, inverse_indices = torch.unique(_indices, return_inverse=True)
rank_mapping = rank_mapping.to(_indices.device)
renumbered_indices = inverse_indices
unique_rank_mapping = torch.zeros_like(unique_indices, dtype=rank_mapping.dtype, device=rank_mapping.device)
unique_rank_mapping = torch.zeros_like(
unique_indices, dtype=rank_mapping.dtype, device=rank_mapping.device
)
unique_rank_mapping.scatter_(0, inverse_indices, rank_mapping)

return renumbered_indices, unique_indices, unique_rank_mapping
@@ -198,3 +204,85 @@ def LocalAggregateWithRemapping(
local_aggregated_data.scatter_add_(1, renumbered_indices, global_data)

return local_aggregated_data, new_mapping


def _unoptimized_RankLocalMultiOutputScatter(
_src: torch.Tensor,
_output: torch.Tensor,
_workspace: torch.Tensor,
local_indices_slice: torch.Tensor,
rank_mapping: torch.Tensor,
cur_rank_offset: int,
rank: int,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Pure-PyTorch fallback: rows of ``_src`` whose destination rank equals ``rank``
are scatter-added into ``_output`` (with ``cur_rank_offset`` subtracted from
their indices), while the remaining rows are scatter-added into ``_workspace``
for later communication to their owning ranks.
"""
num_features = _src.shape[-1]
num_local_output_rows = _output.shape[1]

local_mask = rank_mapping == rank
comm_mask = ~local_mask
local_output_indices_slice = local_indices_slice[local_mask] - cur_rank_offset
local_workspace_indices_slice = local_indices_slice[comm_mask]

_output.scatter_add_(
1,
local_output_indices_slice.view(1, -1, 1).expand(1, -1, num_features),
_src[:, local_mask, :],
)

_workspace.scatter_add_(
1,
local_workspace_indices_slice.view(1, -1, 1).expand(1, -1, num_features),
_src[:, comm_mask, :],
)

return _output, _workspace


def RankLocalMultiOutputScatter(
_src: torch.Tensor,
_output: torch.Tensor,
_workspace: torch.Tensor,
local_indices_slice: torch.Tensor,
rank_mapping: torch.Tensor,
cur_rank_offset: int,
rank: int,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Scatter-add ``_src`` into ``_output`` (rows owned by the current rank) and
``_workspace`` (rows owned by other ranks). Uses the optimized CUDA kernel when
it is available and falls back to the pure-PyTorch implementation otherwise.
"""
if not _LOCAL_OPT_KERNELS_AVAILABLE:
warnings.warn(
"Optimized local kernels are not available. Falling back to the default implementation."
)
return _unoptimized_RankLocalMultiOutputScatter(
_src,
_output,
_workspace,
local_indices_slice,
rank_mapping,
cur_rank_offset,
rank,
)
bs = _src.shape[0]
num_features = _src.shape[-1]
num_local_output_rows = _output.shape[1]
num_workspace_rows = _workspace.shape[1]
num_indices = local_indices_slice.shape[0]
# NOTE: the argument order must follow the C++ signature:
# (input, indices, rank_placement, output, workspace, num_batches,
#  num_values_rows, num_cols, num_output_rows, num_workspace_rows,
#  cur_rank_index_offset, rank).
local_multi_output_scatter(
_src,
local_indices_slice.cuda(),
rank_mapping.cuda(),
_output,
_workspace,
bs,
num_indices,
num_features,
num_local_output_rows,
num_workspace_rows,
cur_rank_offset,
rank,
)
return _output, _workspace
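
Note: a minimal sketch of the two-buffer scatter semantics, exercising the pure-PyTorch fallback directly so it runs on CPU. All shapes, indices, and rank assignments below are hypothetical; RankLocalMultiOutputScatter dispatches to the CUDA kernel instead when the compiled extension is available.

import torch

from DGraph.distributed.RankLocalOps import _unoptimized_RankLocalMultiOutputScatter

# One mini-batch, 4 source rows, 2 features, running as rank 0.
# Source rows 0 and 2 are owned by rank 0; rows 1 and 3 belong to rank 1.
src = torch.arange(8, dtype=torch.float32).reshape(1, 4, 2)
indices = torch.tensor([0, 5, 1, 6])       # destination row for each source row
rank_mapping = torch.tensor([0, 1, 0, 1])  # owning rank of each destination row

output = torch.zeros(1, 4, 2)     # rows owned by the current rank
workspace = torch.zeros(1, 8, 2)  # staging buffer for rows owned by other ranks

output, workspace = _unoptimized_RankLocalMultiOutputScatter(
    src, output, workspace, indices, rank_mapping, cur_rank_offset=0, rank=0
)
# output[0, 0] and output[0, 1] now hold the rank-0 rows;
# workspace[0, 5] and workspace[0, 6] hold the rows destined for rank 1.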
52 changes: 52 additions & 0 deletions DGraph/distributed/csrc/local_data_kernels.cuh
@@ -157,6 +157,58 @@ namespace Local
}
}

__global__ void Multi_Output_Scatter_Kernel(
const float *__restrict__ values,
const long *__restrict__ indices,
const long *__restrict__ rank_placement,
float *__restrict__ output,
float *__restrict__ workspace,
const int mini_batch_size,
const int num_values_rows,
const int num_cols,
const int num_output_rows,
const int num_workspace_rows,
const long cur_rank_index_offset,
const int current_rank)
{
const size_t gidx = threadIdx.x + blockIdx.x * blockDim.x;
const size_t gidy = threadIdx.y + blockIdx.y * blockDim.y;
const size_t gidz = threadIdx.z + blockIdx.z * blockDim.z;

const size_t nthreadsx = gridDim.x * blockDim.x;
const size_t nthreadsy = gridDim.y * blockDim.y;
const size_t nthreadsz = gridDim.z * blockDim.z;

for (size_t mb_i = gidz; mb_i < mini_batch_size; mb_i += nthreadsz)
{
const auto values_offset = mb_i * num_cols * num_values_rows;
const auto output_offset = mb_i * num_cols * num_output_rows;
const auto workspace_offset = mb_i * num_cols * num_workspace_rows;
const auto ind_offset = mb_i * num_values_rows;
const auto rank_placement_offset = mb_i * num_values_rows;

for (size_t row = gidy; row < num_values_rows; row += nthreadsy)
{
const int ind = indices[ind_offset + row];
const int row_rank = rank_placement[rank_placement_offset + row];

// Determine whether to use output or workspace
auto buffer_ptr = row_rank == current_rank ? output : workspace;
const auto buffer_offset = row_rank == current_rank ? output_offset : workspace_offset;
const auto adjusted_ind = row_rank == current_rank ? ind - cur_rank_index_offset : ind;

for (size_t i = gidx; i < num_cols; i += nthreadsx)
{
if (adjusted_ind > -1 && adjusted_ind < (row_rank == current_rank ? num_output_rows : num_workspace_rows))
{
const auto val = values[values_offset + row * num_cols + i];
// Accumulate the raw value so the kernel matches the scatter_add_ fallback in Python.
atomicAdd(&buffer_ptr[buffer_offset + adjusted_ind * num_cols + i], val);
}
}
}
}
}

__global__ void Rank_Local_Gather_Kernel(
const float *__restrict__ values,
const long *__restrict__ indices,
1 change: 1 addition & 0 deletions DGraph/distributed/csrc/torch_local_bindings.cpp
@@ -21,4 +21,5 @@ PYBIND11_MODULE(torch_local, m)
{
m.def("local_masked_gather", &local_masked_gather, "Masked Gather");
m.def("local_masked_scatter", &local_masked_scatter, "Masked Scatter");
m.def("local_multi_output_scatter", &local_multi_output_scatter, "Multi-output Scatter");
}
56 changes: 55 additions & 1 deletion DGraph/distributed/csrc/torch_local_kernels.cu
@@ -114,4 +114,58 @@ torch::Tensor local_masked_scatter(torch::Tensor input,
rank);
CUDACHECK(cudaGetLastError());
return output;
}
}

torch::Tensor local_multi_output_scatter(torch::Tensor input,
torch::Tensor indices,
torch::Tensor rank_local_placement,
torch::Tensor output,
torch::Tensor workspace,
const int num_batches,
const int num_values_rows,
const int num_cols,
const int num_output_rows,
const int num_workspace_rows,
const long cur_rank_index_offset,
const int rank)
{
CHECK_INPUT(input);
CHECK_INPUT(indices);
CHECK_INPUT(rank_local_placement);
CHECK_INPUT(output);
CHECK_INPUT(workspace);

const float *input_ptr = input.data_ptr<float>();
const long *indices_ptr = indices.data_ptr<long>();
const long *rank_local_placement_ptr = rank_local_placement.data_ptr<long>();
float *output_ptr = output.data_ptr<float>();
float *workspace_ptr = workspace.data_ptr<float>();

dim3 block_dims, grid_dims;
block_dims.x = 32;
block_dims.y = 32;
block_dims.z = 1;

const auto num_grids_needed = (num_output_rows + block_dims.y - 1) / block_dims.y;
const auto num_col_grids_needed = (num_cols + block_dims.x - 1) / block_dims.x;
grid_dims.x = num_col_grids_needed < 65535 ? num_col_grids_needed : 65535;
grid_dims.y = num_grids_needed < 65535 ? num_grids_needed : 65535;
grid_dims.z = 1;
// Get the default stream for the current device
at::cuda::CUDAStream defaultStream = at::cuda::getDefaultCUDAStream(input.device().index());
Local::Multi_Output_Scatter_Kernel<<<grid_dims, block_dims>>>(
input_ptr,
indices_ptr,
rank_local_placement_ptr,
output_ptr,
workspace_ptr,
num_batches,
num_values_rows,
num_cols,
num_output_rows,
num_workspace_rows,
cur_rank_index_offset,
rank);
CUDACHECK(cudaGetLastError());
return output;
}
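
Note: a parity check one might add alongside this kernel, assuming the extension built successfully and a CUDA device is available; a single mini-batch keeps the index layout simple, and the shapes are hypothetical.

import torch

from DGraph.distributed.RankLocalOps import (
    RankLocalMultiOutputScatter,
    _unoptimized_RankLocalMultiOutputScatter,
)

torch.manual_seed(0)
bs, rows, feats = 1, 64, 16
src = torch.randn(bs, rows, feats, device="cuda")
indices = torch.randint(0, 32, (rows,), device="cuda")
ranks = torch.randint(0, 2, (rows,), device="cuda")  # this process acts as rank 0

out_ref = torch.zeros(bs, 32, feats, device="cuda")
ws_ref = torch.zeros(bs, 32, feats, device="cuda")
out_opt = out_ref.clone()
ws_opt = ws_ref.clone()

_unoptimized_RankLocalMultiOutputScatter(src, out_ref, ws_ref, indices, ranks, 0, 0)
RankLocalMultiOutputScatter(src, out_opt, ws_opt, indices, ranks, 0, 0)

torch.testing.assert_close(out_opt, out_ref)
torch.testing.assert_close(ws_opt, ws_ref)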
1 change: 1 addition & 0 deletions DGraph/distributed/csrc/torch_nvshmem_p2p.cu
@@ -163,6 +163,7 @@ void NVSHMEMP2P::dist_put(torch::Tensor input,
CUDACHECK(cudaStreamSynchronize(defaultStream));
}


void NVSHMEMP2P::dist_get(torch::Tensor input,
torch::Tensor output,
torch::Tensor indices,
15 changes: 14 additions & 1 deletion DGraph/distributed/include/torch_local.hpp
@@ -19,4 +19,17 @@ torch::Tensor local_masked_scatter(torch::Tensor input,
const int num_values_rows,
const int num_cols,
const int num_output_rows,
const int rank);
const int rank);

torch::Tensor local_multi_output_scatter(torch::Tensor input,
torch::Tensor indices,
torch::Tensor rank_local_placement,
torch::Tensor output,
torch::Tensor workspace,
const int num_batches,
const int num_values_rows,
const int num_cols,
const int num_output_rows,
const int num_workspace_rows,
const long cur_rank_index_offset,
const int rank);
53 changes: 50 additions & 3 deletions DGraph/distributed/nvshmem/NVSHMEMBackendEngine.py
@@ -17,9 +17,51 @@
import DGraph.torch_nvshmem_p2p as nvshmem
import warnings
from torch.autograd import Function
from DGraph.distributed.nvshmem._nvshmem_cache import NVSHMEMScatterCache
from DGraph.distributed.RankLocalOps import RankLocalMultiOutputScatter


def _nvshmmem_gather(send_tensor, indices, rank_mappings):
def _nvshmem_scatter_cache(send_tensor, cache: NVSHMEMScatterCache, workspace):
bs = send_tensor.shape[0]
num_features = send_tensor.shape[2]

num_elem = bs * num_features * cache.num_output_rows
scattered_tensor = nvshmem.NVSHMEMP2P.allocate_symmetric_memory(
num_elem, send_tensor.device.index
)
scattered_tensor.fill_(0).float()
scattered_tensor = scattered_tensor.reshape(
(bs, cache.num_output_rows, num_features)
)
cur_rank = nvshmem.NVSHMEMP2P.get_rank()
cur_rank_offset = int(cache.index_offsets_per_rank[cur_rank].item())
scattered_tensor, workspace = RankLocalMultiOutputScatter(
send_tensor,
scattered_tensor,
workspace,
cache.local_indices_slice,
cache.local_dest_ranks,
cur_rank_offset,
cur_rank,
)
# Now the scattered tensor contains the local contributions.
# The workspace contains the data to be communicated to other ranks,
# already accumulated locally.

nvshmem.NVSHMEMP2P.dist_put(
workspace,
scattered_tensor,
cache.comm_indices,
cache.local_dest_ranks,
bs,
cache.min_workspace_size,
num_features,
cache.num_output_rows,
)
return scattered_tensor
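
For reference, _nvshmem_scatter_cache only touches a handful of precomputed fields on the cache. A hypothetical stand-in showing the attributes it assumes (the real NVSHMEMScatterCache lives in _nvshmem_cache and may carry more state):

from dataclasses import dataclass

import torch


@dataclass
class ExampleScatterCache:
    num_output_rows: int                  # destination rows owned by this rank
    index_offsets_per_rank: torch.Tensor  # first global row index of each rank's partition
    local_indices_slice: torch.Tensor     # destination row index for each local source row
    local_dest_ranks: torch.Tensor        # owning rank for each of those rows
    comm_indices: torch.Tensor            # remote indices consumed by NVSHMEMP2P.dist_put
    min_workspace_size: int               # number of rows required in the staging workspace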


def _nvshmem_gather(send_tensor, indices, rank_mappings):

bs = send_tensor.shape[0]
num_input_rows = send_tensor.shape[1]
@@ -115,7 +157,7 @@ def forward(ctx, send_tensor, indices, rank_mappings):
ctx.save_for_backward(indices, rank_mappings)
num_rows = indices.shape[1]
ctx.num_rows = num_rows
gathered_tensors = _nvshmmem_gather(send_tensor, indices, rank_mappings)
gathered_tensors = _nvshmem_gather(send_tensor, indices, rank_mappings)
return gathered_tensors

@staticmethod
@@ -152,7 +194,7 @@ def backward(ctx, grad_output):
# nvshmem_grad_output =
nvshmem.register_memory(grad_output)
indices, rank_mappings = ctx.saved_tensors
input_grad = _nvshmmem_gather(grad_output, indices, rank_mappings)
input_grad = _nvshmem_gather(grad_output, indices, rank_mappings)
nvshmem.deregister_memory(grad_output)
indices_grad = None
rank_mappings_grad = None
@@ -169,6 +211,7 @@ class NVSHMEMBackendEngine(BackendEngine):
_partition_size = -1
_local_rank = -1
_partition_num = -1
_STATIC_CACHE = None

def __init__(self, ranks_per_graph=-1, *args, **kwargs):
# check if already initialized
@@ -221,6 +264,10 @@ def get_partition_num() -> int:
def get_partition_size() -> int:
return NVSHMEMBackendEngine._partition_size

def add_cache(self, cache: NVSHMEMScatterCache) -> None:
NVSHMEMBackendEngine._STATIC_CACHE = cache
return
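
A hedged usage note: the cache is installed once on the engine instance and stored class-wide; the cache object below is hypothetical.

engine = NVSHMEMBackendEngine()
engine.add_cache(precomputed_scatter_cache)  # stored in NVSHMEMBackendEngine._STATIC_CACHE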

def gather(self, input_tensor, indices, rank_mappings):
assert (
len(input_tensor.shape) == 3