4 changes: 3 additions & 1 deletion DGraph/distributed/RankLocalOps.py
@@ -140,7 +140,9 @@ def RankLocalRenumberingWithMapping(_indices, rank_mapping):
unique_indices, inverse_indices = torch.unique(_indices, return_inverse=True)
rank_mapping = rank_mapping.to(_indices.device)
renumbered_indices = inverse_indices
unique_rank_mapping = torch.zeros_like(unique_indices, dtype=rank_mapping.dtype, device=rank_mapping.device)
unique_rank_mapping = torch.zeros_like(
unique_indices, dtype=rank_mapping.dtype, device=rank_mapping.device
)
unique_rank_mapping.scatter_(0, inverse_indices, rank_mapping)

return renumbered_indices, unique_indices, unique_rank_mapping
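
For context, this helper densifies arbitrary global indices and carries a per-index rank over to each unique slot. A minimal self-contained sketch of the same pattern, with made-up inputs:

```python
import torch

# Dense renumbering via torch.unique(return_inverse=True): duplicates of the
# same global index collapse onto one slot, and scatter_ writes each index's
# rank into that slot (duplicates are assumed to agree on the rank).
_indices = torch.tensor([7, 3, 7, 9, 3])
rank_mapping = torch.tensor([1, 0, 1, 2, 0])

unique_indices, inverse_indices = torch.unique(_indices, return_inverse=True)
unique_rank_mapping = torch.zeros_like(unique_indices)
unique_rank_mapping.scatter_(0, inverse_indices, rank_mapping)

print(unique_indices)       # tensor([3, 7, 9])
print(inverse_indices)      # tensor([1, 0, 1, 2, 0])  -- the renumbered indices
print(unique_rank_mapping)  # tensor([0, 1, 2])
```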
5 changes: 5 additions & 0 deletions DGraph/distributed/nccl/NCCLBackendEngine.py
@@ -512,6 +512,11 @@ def __init__(self, ranks_per_graph=-1, *args, **kwargs):
if not NCCLBackendEngine._is_initialized:
self.init_process_group(ranks_per_graph)

def barrier(self) -> None:
if not dist.is_initialized():
raise RuntimeError("NCCL backend engine is not initialized")
dist.barrier()

def init_process_group(self, ranks_per_graph=-1, *args, **kwargs):
if not dist.is_initialized():
dist.init_process_group(backend="nccl", *args, **kwargs)
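
The new barrier() simply guards dist.barrier() behind an initialization check. A hedged usage sketch (the engine construction and the timed function are hypothetical, not part of this diff):

```python
import time

def timed_region(engine, fn):
    # engine is assumed to be an initialized NCCLBackendEngine
    engine.barrier()   # raises RuntimeError if init_process_group never ran
    start = time.perf_counter()
    fn()               # the communication pattern being measured
    engine.barrier()   # wait for all ranks before stopping the clock
    return time.perf_counter() - start
```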
160 changes: 91 additions & 69 deletions experiments/Benchmarks/TestNCCL.py
@@ -169,7 +169,7 @@ def run_scatter_benchmark(

def main():
parser = argparse.ArgumentParser()
parser.add_argument("--message_size", type=int, default=128)
parser.add_argument("--message_size", type=int, default=2)
parser.add_argument("--benchmark_cache", action="store_true")
parser.add_argument("--num_iters", type=int, default=1000)
parser.add_argument("--log_dir", type=str, default="logs")
@@ -196,92 +196,114 @@ def main():
benchmark.print(f"Running NCCL Benchmark on {world_size} ranks")

# Built in small message benchmarks, in future we can add more
gather_graph_data = get_nccl_gather_benchmark_data(message_size, world_size, device)

benchmark.print("*" * 50)
benchmark.print("Running Gather Benchmark")
times = run_gather_benchmark(benchmark, num_iters, gather_graph_data, cache=None)

benchmark.print("Saving Gather Benchmark Times")

for i in range(world_size):
benchmark.save_np(times, f"{log_dir}/NCCL_gather_times_{i}.npy", rank_to_save=i)
for i in range(1, 20):
message_size *= 2
benchmark.print("*" * 50)
benchmark.print(f"Running NCCL Benchmark for message size {message_size}")
gather_graph_data = get_nccl_gather_benchmark_data(
message_size, world_size, device
)
dist.barrier()

benchmark.print("Gather Benchmark Complete")
benchmark.print("*" * 50)
benchmark.print("Running Gather Benchmark")
times = run_gather_benchmark(
benchmark, num_iters, gather_graph_data, cache=None
)

if benchmark_cache:
edge_placement = gather_graph_data.edge_rank_placement
edge_src_rank = gather_graph_data.edge_src_rank
indices = gather_graph_data.edge_indices
benchmark.print("Saving Gather Benchmark Times")

gather_cache = NCCLGatherCacheGenerator(
indices,
edge_placement.view(-1),
edge_src_rank.view(-1),
1,
rank,
world_size,
benchmark.save_np(
times,
f"{log_dir}/NCCL_gather_times_message_size_{message_size}"
+ f"_world_size_{world_size}.npy",
rank_to_save=0,
)

benchmark.print("Gather Benchmark Complete")
benchmark.print("*" * 50)
benchmark.print("Running Gather Benchmark with Cache")
times = run_gather_benchmark(
benchmark, num_iters, gather_graph_data, cache=gather_cache
)

benchmark.print("Saving Gather Benchmark with Cache Times")
for i in range(world_size):
if benchmark_cache:
edge_placement = gather_graph_data.edge_rank_placement
edge_src_rank = gather_graph_data.edge_src_rank
indices = gather_graph_data.edge_indices

gather_cache = NCCLGatherCacheGenerator(
indices,
edge_placement.view(-1),
edge_src_rank.view(-1),
1,
rank,
world_size,
)
benchmark.print("*" * 50)
benchmark.print("Running Gather Benchmark with Cache")
times = run_gather_benchmark(
benchmark, num_iters, gather_graph_data, cache=gather_cache
)

benchmark.print("Saving Gather Benchmark with Cache Times")
benchmark.save_np(
times, f"{log_dir}/NCCL_gather_with_cache_times_{i}.npy", rank_to_save=i
times,
f"{log_dir}/NCCL_gather_with_cache_message_size_{message_size}"
+ f"_world_size_{world_size}.npy",
rank_to_save=0,
)

benchmark.print("Gather Benchmark with Cache Complete")
benchmark.print("*" * 50)
benchmark.print("Gather Benchmark with Cache Complete")
benchmark.print("*" * 50)

scatter_graph_data = get_nccl_scatter_benchmark_data(
message_size, world_size, device
)
benchmark.print("*" * 50)
benchmark.print("Running Scatter Benchmark")
times = run_scatter_benchmark(benchmark, num_iters, scatter_graph_data, cache=None)
scatter_graph_data = get_nccl_scatter_benchmark_data(
message_size, world_size, device
)

benchmark.print("Saving Scatter Benchmark Times")
for i in range(world_size):
benchmark.save_np(
times, f"{log_dir}/NCCL_scatter_times_{i}.npy", rank_to_save=i
)
benchmark.print("*" * 50)
benchmark.print("Running Scatter Benchmark")
times = run_scatter_benchmark(
benchmark, num_iters, scatter_graph_data, cache=None
)

benchmark.print("Scatter Benchmark Complete")
benchmark.print("*" * 50)
if benchmark_cache:
edge_placement = scatter_graph_data.edge_rank_placement
edge_dest_rank = scatter_graph_data.edge_dest_rank
indices = scatter_graph_data.edge_indices

scatter_cache = NCCLScatterCacheGenerator(
indices,
edge_placement.view(-1),
edge_dest_rank.view(-1),
1,
rank,
world_size,
)
benchmark.print("*" * 50)
benchmark.print("Running Scatter Benchmark with Cache")
times = run_scatter_benchmark(
benchmark, num_iters, scatter_graph_data, cache=scatter_cache
)
benchmark.print("Saving Scatter Benchmark Times")

benchmark.print("Saving Scatter Benchmark with Cache Times")
for i in range(world_size):
benchmark.save_np(
times,
f"{log_dir}/NCCL_scatter_with_cache_times_{i}.npy",
rank_to_save=i,
f"{log_dir}/NCCL_scatter_times_message_size_{message_size}"
+ f"_world_size_{world_size}.npy",
rank_to_save=0,
)

benchmark.print("Scatter Benchmark with Cache Complete")
benchmark.print("*" * 50)
benchmark.print("Scatter Benchmark Complete")
benchmark.print("*" * 50)
if benchmark_cache:
edge_placement = scatter_graph_data.edge_rank_placement
edge_dest_rank = scatter_graph_data.edge_dest_rank
indices = scatter_graph_data.edge_indices

scatter_cache = NCCLScatterCacheGenerator(
indices,
edge_placement.view(-1),
edge_dest_rank.view(-1),
1,
rank,
world_size,
)
benchmark.print("*" * 50)
benchmark.print("Running Scatter Benchmark with Cache")
times = run_scatter_benchmark(
benchmark, num_iters, scatter_graph_data, cache=scatter_cache
)

benchmark.print("Saving Scatter Benchmark with Cache Times")

benchmark.save_np(
times,
f"{log_dir}/NCCL_scatter_with_cache_message_size_{message_size}"
+ f"_world_size_{world_size}.npy",
rank_to_save=0,
)

benchmark.print("Scatter Benchmark with Cache Complete")
benchmark.print("*" * 50)

dist.destroy_process_group()
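
Both benchmark scripts now sweep message sizes rather than timing a single size: starting from --message_size (default now 2), the size doubles 19 times, and rank 0 writes one file per (message size, world size) pair instead of one file per rank. A schematic of the loop, with the data and run helpers abridged to hypothetical names:

```python
message_size = args.message_size              # default is now 2
for _ in range(1, 20):                        # 19 doublings, up to message_size * 2**19
    message_size *= 2
    data = make_benchmark_data(message_size, world_size, device)  # hypothetical helper
    times = run_benchmark(benchmark, num_iters, data)             # hypothetical helper
    benchmark.save_np(                        # one file per size, written by rank 0 only
        times,
        f"{log_dir}/times_message_size_{message_size}_world_size_{world_size}.npy",
        rank_to_save=0,
    )
```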

48 changes: 28 additions & 20 deletions experiments/Benchmarks/TestNVSHMEM.py
@@ -127,7 +127,7 @@ def run_scatter_benchmark(

def main():
parser = argparse.ArgumentParser()
parser.add_argument("--message_size", type=int, default=128)
parser.add_argument("--message_size", type=int, default=2)
parser.add_argument("--benchmark_cache", action="store_true")
parser.add_argument("--num_iters", type=int, default=1000)
parser.add_argument("--log_dir", type=str, default="logs")
@@ -160,37 +160,45 @@ def main():
benchmark.print("*" * 50)
benchmark.print("Running Gather Benchmark")

gather_graph_data = get_nvshmem_gather_benchmark_data(
message_size, rank, world_size, device
)
times = run_gather_benchmark(benchmark, num_iters, gather_graph_data)
for i in range(1, 20):
message_size *= 2
benchmark.print(f"Running NCCL Benchmark for message size {message_size}")

gather_graph_data = get_nvshmem_gather_benchmark_data(
message_size, rank, world_size, device
)
times = run_gather_benchmark(benchmark, num_iters, gather_graph_data)

benchmark.print("Saving Gather Benchmark Times")
benchmark.print("Saving Gather Benchmark Times")

for i in range(world_size):
benchmark.save_np(
times, f"{log_dir}/NVSHMEM_gather_times_{i}.npy", rank_to_save=i
times,
f"{log_dir}/NVSHMEM_gather_times_message_size_{message_size}"
+ f"_with_world_size_{world_size}.npy",
rank_to_save=0,
)

benchmark.print("Gather Benchmark Complete")
benchmark.print("*" * 50)
benchmark.print("Gather Benchmark Complete")
benchmark.print("*" * 50)

scatter_graph_data = get_nvshmem_scatter_benchmark_data(
message_size, rank, world_size, device
)
scatter_graph_data = get_nvshmem_scatter_benchmark_data(
message_size, rank, world_size, device
)

benchmark.print("Running Scatter Benchmark")
times = run_scatter_benchmark(benchmark, num_iters, scatter_graph_data)
benchmark.print("Running Scatter Benchmark")
times = run_scatter_benchmark(benchmark, num_iters, scatter_graph_data)

benchmark.print("Saving Scatter Benchmark Times")
benchmark.print("Saving Scatter Benchmark Times")

for i in range(world_size):
benchmark.save_np(
times, f"{log_dir}/NVSHMEM_scatter_times_{i}.npy", rank_to_save=i
times,
f"{log_dir}/NVSHMEM_scatter_times_message_size_{message_size}"
+ f"_with_world_size_{world_size}.npy",
rank_to_save=0,
)

benchmark.print("Scatter Benchmark Complete")
benchmark.print("*" * 50)
benchmark.print("Scatter Benchmark Complete")
benchmark.print("*" * 50)


if __name__ == "__main__":
2 changes: 1 addition & 1 deletion experiments/Benchmarks/generate_plots.py
@@ -76,5 +76,5 @@ def generate_cache_comparison_plot():

if __name__ == "__main__":
generate_plots("nccl")
generate_plots("nvshmem")
# generate_plots("nvshmem")
generate_cache_comparison_plot()
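
Since the timing files now embed the message size and world size in their names, downstream plotting presumably discovers them by pattern. A hypothetical loader for the new naming scheme (not part of this diff; the prefix is an assumed default):

```python
import glob
import re

import numpy as np

def load_sweep(log_dir: str, prefix: str = "NCCL_gather_times"):
    """Collect mean times per message size from the renamed .npy files."""
    sizes, means = [], []
    for path in glob.glob(f"{log_dir}/{prefix}_message_size_*.npy"):
        match = re.search(r"message_size_(\d+)_", path)
        if match:
            sizes.append(int(match.group(1)))
            means.append(np.load(path).mean())
    order = np.argsort(sizes)                 # sort numerically by message size
    return np.asarray(sizes)[order], np.asarray(means)[order]
```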
5 changes: 5 additions & 0 deletions experiments/GraphCast/README.md
@@ -38,3 +38,8 @@ Run with benchmarking with the following command:
python main.py --benchmark
```
***Note:*** The graph requires a large amount of memory, so it is better to run on the CPU on a machine with a large amount of memory.

Run with multiple processes per GPU with the following command:
```bash
torchrun-hpc --xargs=--mpibind=off --xargs=--gpu-bind=none train_graphcast.py --is_distributed True --procs_per_gpu 4
```
4 changes: 3 additions & 1 deletion experiments/GraphCast/data_utils/graphcast_graph.py
@@ -219,7 +219,8 @@ def get_grid2mesh_graph(self, mesh_graph_dict: dict):
contigous_edge_mapping, renumbered_edges = torch.sort(meshtogrid_edge_placement)

src_grid_indices = src_grid_indices[renumbered_edges]
grid_vertex_rank_placement = torch.zeros_like(lat_lon_grid_flat)
grid_vertex_rank_placement = torch.zeros_like(lat_lon_grid_flat[:, 0])

for i, rank in enumerate(meshtogrid_edge_placement):
loc = src_grid_indices[i]
grid_vertex_rank_placement[loc] = rank
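
The one-character fix above matters because lat_lon_grid_flat stores a (lat, lon) pair per grid vertex, so zeros_like of the full tensor allocates one slot per coordinate rather than per vertex. Illustrative shapes, assuming an (N, 2) flattened grid:

```python
import torch

lat_lon_grid_flat = torch.zeros(6, 2)             # 6 vertices, (lat, lon) each
bad = torch.zeros_like(lat_lon_grid_flat)         # shape (6, 2): one slot per coordinate
good = torch.zeros_like(lat_lon_grid_flat[:, 0])  # shape (6,): one rank per vertex
good[3] = 1.0                                     # per-vertex assignment indexes cleanly
# with `bad`, grid_vertex_rank_placement[loc] = rank would fill a whole row instead
```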
@@ -254,6 +255,7 @@ def get_mesh2grid_graph(
)

edge_features, src_mesh_indices, dst_grid_indices = m2g_graph
breakpoint()
src_mesh_indices = renumbered_vertices[src_mesh_indices]
dst_grid_indices = renumbered_grid[dst_grid_indices]

1 change: 1 addition & 0 deletions experiments/GraphCast/dataset.py
@@ -87,6 +87,7 @@ def __init__(
self.lat_lon_grid = torch.stack(
torch.meshgrid(self.latitudes, self.longitudes, indexing="ij"), dim=-1
)

self.graph_cast_graph = DistributedGraphCastGraphGenerator(
self.lat_lon_grid,
mesh_level=self.mesh_level,
2 changes: 2 additions & 0 deletions experiments/GraphCast/layers.py
@@ -20,6 +20,8 @@
from DGraph.Communicator import Communicator
from dist_utils import SingleProcessDummyCommunicator

# class MLPSiLuWithRecompute(nn.Module):


class MeshGraphMLP(nn.Module):
"""MLP for graph processing"""
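
The stubbed-out MLPSiLuWithRecompute above is only a comment so far; a plausible reading, given the memory pressure noted in the README, is an MLP that recomputes activations in the backward pass. A speculative sketch using activation checkpointing (purely illustrative, not the authors' design):

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class MLPSiLuWithRecompute(nn.Module):
    """MLP that trades compute for memory by recomputing SiLU activations."""

    def __init__(self, in_dim: int, hidden_dim: int, out_dim: int):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.SiLU(),
            nn.Linear(hidden_dim, out_dim),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # recompute the forward pass inside backward instead of storing activations
        return checkpoint(self.net, x, use_reentrant=False)
```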
30 changes: 17 additions & 13 deletions experiments/GraphCast/model.py
@@ -330,7 +330,10 @@ def __init__(self, cfg: Config, comm, *args, **kwargs):
)

def forward(
self, input_grid_features: Tensor, static_graph: DistributedGraphCastGraph
self,
input_grid_features: Tensor,
static_graph: DistributedGraphCastGraph,
device: Optional[torch.device] = None,
) -> Tensor:
"""
Args:
@@ -340,18 +343,19 @@ def forward(
Returns:
(Tensor): The predicted output grid
"""

input_grid_features = input_grid_features.squeeze(0)
input_mesh_features = static_graph.mesh_graph_node_features
mesh2mesh_edge_features = static_graph.mesh_graph_edge_features
grid2mesh_edge_features = static_graph.grid2mesh_graph_edge_features
mesh2grid_edge_features = static_graph.mesh2grid_graph_edge_features
mesh2mesh_edge_indices_src = static_graph.mesh_graph_src_indices
mesh2mesh_edge_indices_dst = static_graph.mesh_graph_dst_indices
mesh2grid_edge_indices_src = static_graph.mesh2grid_graph_src_indices
mesh2grid_edge_indices_dst = static_graph.mesh2grid_graph_dst_indices
grid2mesh_edge_indices_src = static_graph.grid2mesh_graph_src_indices
grid2mesh_edge_indices_dst = static_graph.grid2mesh_graph_dst_indices
if device is None:
device = input_grid_features.device
input_grid_features = input_grid_features.squeeze(0).to(device)
input_mesh_features = static_graph.mesh_graph_node_features.to(device)
mesh2mesh_edge_features = static_graph.mesh_graph_edge_features.to(device)
grid2mesh_edge_features = static_graph.grid2mesh_graph_edge_features.to(device)
mesh2grid_edge_features = static_graph.mesh2grid_graph_edge_features.to(device)
mesh2mesh_edge_indices_src = static_graph.mesh_graph_src_indices.to(device)
mesh2mesh_edge_indices_dst = static_graph.mesh_graph_dst_indices.to(device)
mesh2grid_edge_indices_src = static_graph.mesh2grid_graph_src_indices.to(device)
mesh2grid_edge_indices_dst = static_graph.mesh2grid_graph_dst_indices.to(device)
grid2mesh_edge_indices_src = static_graph.grid2mesh_graph_src_indices.to(device)
grid2mesh_edge_indices_dst = static_graph.grid2mesh_graph_dst_indices.to(device)

out = self.embedder(
input_grid_features,
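
The new optional device argument lets the large static graph stay on the CPU (as the README recommends) while compute runs on an accelerator, since forward() now moves each tensor on demand. A hedged call-site sketch, with model, static_graph, and input_grid_features assumed to come from the surrounding training code:

```python
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)  # parameters live on the accelerator
# static_graph tensors can remain on CPU; forward() moves each one as needed
prediction = model(input_grid_features, static_graph, device=device)
```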