From 56bf43a012ebfbb84b67f86a3d1ef07b425d12d9 Mon Sep 17 00:00:00 2001 From: M Shehtab Zaman Date: Thu, 1 May 2025 09:02:09 -0700 Subject: [PATCH 1/8] Adding pre-process codes and Generator --- .../Synthetic-Billion/Generator/Generator.cpp | 290 ++++++++++++++++++ .../Synthetic-Billion/Generator/Makefile | 0 experiments/Synthetic-Billion/README.md | 43 +++ experiments/Synthetic-Billion/preprocess.py | 153 +++++++++ 4 files changed, 486 insertions(+) create mode 100644 experiments/Synthetic-Billion/Generator/Generator.cpp create mode 100644 experiments/Synthetic-Billion/Generator/Makefile create mode 100644 experiments/Synthetic-Billion/README.md create mode 100644 experiments/Synthetic-Billion/preprocess.py diff --git a/experiments/Synthetic-Billion/Generator/Generator.cpp b/experiments/Synthetic-Billion/Generator/Generator.cpp new file mode 100644 index 0000000..49b0c35 --- /dev/null +++ b/experiments/Synthetic-Billion/Generator/Generator.cpp @@ -0,0 +1,290 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +// using adj_list = std::unordered_map>; +// Function to generate a random connected graph with a maximum degree constraint +// and return it as an adjacency list in an unordered_map (single-threaded). + +bool find_in_adj_list(const std::vector &adj_list, unsigned long vertex) +{ + return std::find(adj_list.begin(), adj_list.end(), vertex) != adj_list.end(); +} + +using adj_list = std::vector>; + +std::tuple generate_graph( + const unsigned long &num_vertices, + const int &max_degree) +{ + adj_list adjacency_map; + + if (num_vertices == 0) + { + std::cerr << "Warning: Number of vertices is 0. Returning empty map." << std::endl; + return std::tuple(adjacency_map, 0); + } + + if (max_degree <= 0) + { + std::cerr << "Error: Maximum degree must be greater than 0." << std::endl; + // Depending on requirements, might return empty or throw exception + return std::tuple(adjacency_map, 0); + } + + // Keep track of the current degree of each vertex. + std::vector degree(num_vertices, 0); + + adjacency_map.reserve(num_vertices); + unsigned seed = std::random_device()(); + std::mt19937_64 rng(seed); + std::uniform_int_distribution dist_vertices(0, num_vertices - 1); + + // 1. Build a random spanning tree to ensure connectivity (for num_vertices > 1). + + std::vector visited; + visited.reserve(num_vertices); + visited.push_back(0); // Start with vertex 0 + std::vector unvisited; + unvisited.reserve(num_vertices - 1); + + for (unsigned long i = 0; i < num_vertices; ++i) + { + adjacency_map.push_back(std::vector(max_degree, num_vertices)); + } + for (unsigned long i = 1; i < num_vertices; ++i) + { + unvisited.push_back(i); + } + + std::shuffle(unvisited.begin(), unvisited.end(), rng); + + long num_edges = 0; + auto num_vertex_processed = 1; + std::cout << "Building spanning tree..." 
<< std::endl; + + for (auto u : unvisited) + { + // Randomly select a vertex from the visited set + std::uniform_int_distribution + dist_visited(0, visited.size() - 1); + unsigned long v_idx = dist_visited(rng); + unsigned long v = visited[v_idx]; + auto max_degree_check = std::max(degree[u], degree[v]) < max_degree; + + while (!max_degree_check) + { + // If the edge already exists or max degree is exceeded, try again + v_idx = dist_visited(rng); + v = visited[v_idx]; + max_degree_check = std::max(degree[u], degree[v]) < max_degree; + } + + adjacency_map[u][degree[u]] = v; + adjacency_map[v][degree[v]] = u; + + // std::cout << std::endl; + degree[u]++; + degree[v]++; + + num_edges++; + num_vertex_processed++; + + visited.push_back(u); + } + + + + + std::cout << "Number of edges in the spanning tree: " << num_edges << std::endl; + if (visited.size() != num_vertices) + { + std::cerr << "Warning: Could not connect all vertices while building spanning tree (possibly due to low max_degree). Graph might be disconnected." << std::endl; + } + + visited.clear(); + unvisited.clear(); + + // Graph is connected with all edges having at least one vertex + // and a maximum degree of max_degree. + + // Add additional random edges between vertices (will not exceed max_degree). + // Graph will of course remain connected. + + // heuristic + const auto max_new_edges = num_edges * 2; + const auto min_new_edges = num_edges / 2; + + // chose a number between min_new_edges and max_new_edges + std::uniform_int_distribution dist_edges(min_new_edges, max_new_edges); + const auto num_edges_to_add = dist_edges(rng); + + unsigned long max_possible_edges = num_vertices * (max_degree) / 4; + + unsigned long current_edges_count = num_edges; + // Attempt to add more edges for a reasonable number of times + + const auto max_possible_edges = num_vertices * (num_vertices - 1) / 2; + + std::cout << "Total edge attempts: " << num_edges_to_add << std::endl; + unsigned int attempts = 0; + while (attempts < num_edges_to_add) + { + attempts++; // Count attempt even if unsuccessful + unsigned long u = dist_vertices(rng); + unsigned long v = dist_vertices(rng); + + if (u == v) + { + continue; // No self-loops + } + + unsigned long first = std::min(u, v); + unsigned long second = std::max(u, v); + + const auto &adj_list = adjacency_map[first]; + // Check if the edge already exists + + // const auto not_in_adj_list = adj_list.find(second) == adj_list.end(); + const auto in_adj_list = find_in_adj_list(adj_list, second); + // Check if edge exists and if adding it violates max_degree + if (!in_adj_list && std::max(degree[u], degree[v]) < max_degree) + { + adjacency_map[u][degree[u]] = v; + adjacency_map[v][degree[v]] = u; // Add the edge in both directions + degree[u]++; + degree[v]++; + current_edges_count++; + } + if (current_edges_count >= max_possible_edges) + { // Stop if we reach the maximum possible edges + break; + } + } + + return std::tuple(adjacency_map, current_edges_count); +} + +void write_graph_file( + const adj_list &graph, + const std::string &filename, + const unsigned long &num_vertices, + const unsigned long &num_edges) +{ + std::cout << "Write_graph_file called for file " << filename << std::endl; + // Example: Simulate writing to a file + std::ofstream outfile(filename); + + if (!outfile.is_open()) + { + std::cerr << "Error opening file " << filename << " for writing." 
<< std::endl; + return; + } + + outfile << num_vertices << " " << num_edges << std::endl; + for (unsigned long i = 0; i < num_vertices; ++i) + { + const auto &neighbors = graph.at(i); + outfile << neighbors[0] + 1; + + for (auto it = neighbors.begin() + 1; it != neighbors.end(); ++it) + { + if (*it < num_vertices) + { + outfile << " " << *it + 1; + } + } + outfile << '\n'; + } + outfile.close(); +} + +int main(int argc, char *argv[]) +{ + // Check if the correct number of command-line arguments is provided + if (argc != 4) + { + std::cerr << "Usage: " << argv[0] << " " << std::endl; + return 1; // Indicate an error + } + + unsigned long num_vertices; + unsigned int max_degree_uint; + std::string output_filename; + + // Parse the number of vertices + try + { + // Use std::stoul for unsigned long + num_vertices = std::stoul(argv[1]); + // Optional: Add a check for extremely large values if needed + if (num_vertices == 0) + { + std::cerr << "Error: Number of vertices must be greater than 0." << std::endl; + return 1; + } + } + catch (const std::invalid_argument &ia) + { + std::cerr << "Error: Invalid number of vertices provided: " << ia.what() << std::endl; + return 1; + } + catch (const std::out_of_range &oor) + { + std::cerr << "Error: Number of vertices out of range: " << oor.what() << std::endl; + return 1; + } + + // Parse the maximum degree + try + { + // Use std::stoi for int, then cast to unsigned int + int max_degree_int = std::stoi(argv[2]); + if (max_degree_int < 1) + { + std::cerr << "Error: Maximum degree must be at least 1." << std::endl; + return 1; + } + // Check if the integer value fits into an unsigned int + if (max_degree_int > 10) + { + std::cerr << "Error: Maximum degree value too large for unsigned int." << std::endl; + return 1; + } + max_degree_uint = static_cast(max_degree_int); + } + catch (const std::invalid_argument &ia) + { + std::cerr << "Error: Invalid maximum degree provided: " << ia.what() << std::endl; + return 1; + } + catch (const std::out_of_range &oor) + { + std::cerr << "Error: Maximum degree out of range: " << oor.what() << std::endl; + return 1; + } + + // The third argument is the filename + output_filename = argv[3]; + + // Cast max_degree_uint to int for the generate_graph_map function signature + int max_degree_int_for_func = static_cast(max_degree_uint); + + const auto graph_adj_list = generate_graph(num_vertices, max_degree_int_for_func); + + // Print the generated adjacency list + std::cout << "Generated Graph Adjacency List:" << std::endl; + // Iterate through vertices 0 to num_vertices-1 for consistent output order + + const auto &adj_list = std::get<0>(graph_adj_list); + const auto &num_edges = std::get<1>(graph_adj_list); + write_graph_file(adj_list, output_filename, num_vertices, num_edges); + return 0; +} \ No newline at end of file diff --git a/experiments/Synthetic-Billion/Generator/Makefile b/experiments/Synthetic-Billion/Generator/Makefile new file mode 100644 index 0000000..e69de29 diff --git a/experiments/Synthetic-Billion/README.md b/experiments/Synthetic-Billion/README.md new file mode 100644 index 0000000..8a43815 --- /dev/null +++ b/experiments/Synthetic-Billion/README.md @@ -0,0 +1,43 @@ +## Billion Vertex Graphs + +As DGraph allows us to learn extremely large graphs, we push the size of the graphs beyond to train with full graph GNN training. We generate a synthetic graphs with 1 billion vertices. + +## Data Generation + +### Building the Graph Generator +We provide a fast graph generator to generate large graphs. 
The generator generates a graph with a given number of vertices and a maximum degree. The generator just requires a `GCC>10.3`. Build the generator in the `Generator` directory +```bash +cd Generator +make +``` + +### Generating the Graph +The generator takes the number of vertices and the maximum degree as input, and outputs a text file in the METIS graph format. Run the following command to generate a graph with 1 billion vertices with a maximum degree of 5: + +```bash +./Generator/graph_generator 1000000000 5 1B5D.graph +``` + +This will generate an undirected graph with 1 billion vertices and a maximum degree of 5. The graph will be saved in the file `1B5D.graph`. The generator will take a few minutes to run and require `~150GB` of memory. + +The graph will be generated in the METIS format, which is a simple text format that describes the graph. The first line of the file contains the number of vertices and edges. The i-th line of the file contains the neighbors of the i-th vertex. + +### Partition the graph + +We assume a there is a working `METIS` installation with flags `i64=1` and `r64=1`. `Parametis` may be useful as well. + +To partition the graph in to `` partitions, run the following command: +```bash +gpmetis 1B5D.graph +``` +This will generate a file `1B5D.graph.part.` which contains the partitioning of the graph. The i-th line of the file contains the partition id of the i-th vertex. The partition ids are 0-indexed. This also requires `~150GB` of memory (with the flag `-ondisk`). + +### Preprocess for DGraph + +To finish the graph generation and make the data ready for DGraph, we take the graph file and partition file and run the following command: +```bash +python preprocess.py --g --p --np +``` + +The script will generate the necessary files for DGraph to run a distributed training partitioned in `` partitions. + diff --git a/experiments/Synthetic-Billion/preprocess.py b/experiments/Synthetic-Billion/preprocess.py new file mode 100644 index 0000000..39d87d6 --- /dev/null +++ b/experiments/Synthetic-Billion/preprocess.py @@ -0,0 +1,153 @@ +import os +import torch +import argparse +import numpy as np +from tqdm import tqdm +import concurrent.futures + +args = argparse.ArgumentParser() +args.add_argument("--g", type=str, default="default.graph", help="graph file name") +args.add_argument( + "--p", type=str, default="partition.graph.N", help="partition file name" +) +args.add_argument("--np", type=int, default=8, help="number of partitions") + + +def safe_mkdir(path): + if not os.path.exists(path): + os.makedirs(path) + + +args = args.parse_args() +graph_file = args.g +partition_file = args.p +num_partitions = args.np + + +def reorder_vertices(partition_file): + """ + Reorder vertices in the partition file to ensure that the vertex IDs are continuous. + """ + vertex_rank_placement = [] + with open(partition_file, "r") as f: + for line in tqdm(f): + line = line.strip() + vertex_rank_placement.append(int(line)) + # This sorts the vertices so that the vertex IDs are continuous in each + # partition. + vertex_rank_placement = torch.from_numpy(np.array(vertex_rank_placement)) + sorted_rank_placement, sorted_indices = torch.sort(vertex_rank_placement) + + # We need the inverse mapping so we can take the COO list and map it to the new vertex IDs. + # This is the inverse mapping from the sorted indices to the original indices. 
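    # A minimal sketch of why sorting sorted_indices yields the inverse permutation
    # (hypothetical 4-vertex placement, not taken from the patch):
    #   vertex_rank_placement = [2, 0, 3, 1]
    #   sorted_indices        = [1, 3, 0, 2]   # original vertex IDs in their new order
    #   reverse_maps          = [2, 0, 3, 1]   # original vertex ID -> new position
    # so reverse_maps[sorted_indices[i]] == i for every i, which is exactly the
    # mapping needed to renumber the COO edge list later on.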
+ + _, reverse_maps = torch.sort(sorted_indices) + + return sorted_rank_placement, sorted_indices, reverse_maps + + +def process_chunk(process_local_adj_list, start_index, end_index, reverse_map): + local_coo_list = [] + for local_idx, line in enumerate(process_local_adj_list): + glocal_idx = start_index + local_idx + src_vertex = reverse_map[glocal_idx].item() + for dst_vertex in line.split(): + dst_vertex = reverse_map[int(dst_vertex) - 1].item() + local_coo_list.append((src_vertex, dst_vertex)) + return local_coo_list + + +def reorder_edge_list(graph_file, reverse_map): + + # This file is quite large usually, so try to do multiprocessing + adj_list = [] + + with open(graph_file, "r") as f: + first_line = f.readline() + num_vertices, num_edges = map(int, first_line.strip().split()) + assert num_vertices == len(reverse_map) + + for i in tqdm(range(num_vertices)): + line = f.readline() + line = line.strip() + adj_list.append(line) + + num_cpus = os.cpu_count() or 1 + # num_workers = max(num_cpus - 2, 1) + num_workers = 8 + chunk_size = max(1, num_vertices // num_workers) + + worker_results = [] + + with concurrent.futures.ProcessPoolExecutor(max_workers=num_workers) as executor: + for i in range(0, num_vertices, chunk_size): + start_index = i + end_index = min(i + chunk_size, num_vertices) + worker_results.append( + executor.submit( + process_chunk, + adj_list[start_index:end_index], + start_index, + end_index, + reverse_map, + ) + ) + + for future in tqdm(concurrent.futures.as_completed(worker_results)): + result = future.result() + if result is not None: + worker_results.extend(result) + else: + print("Worker failed to process chunk.") + + coo_list = torch.tensor( + np.array(worker_results), dtype=torch.int64 + ) # shape: (num_edges, 2) + assert coo_list.shape[1] == 2 + assert coo_list.shape[0] == 2 * num_edges + + # src_vertex = reverse_map[i] + # for dst_vertex in line.split(): + + # dst_vertex = reverse_map[int(dst_vertex) - 1] + # coo_list[edge_counter, 0] = src_vertex + # coo_list[edge_counter, 1] = dst_vertex + # edge_counter += 1 + # assert edge_counter == 2 * num_edges + + return coo_list + + +def main(): + # Reorder vertices in the partition file + # safe_mkdir("processed") + # sorted_rank_placement, sorted_indices, reverse_maps = reorder_vertices( + # partition_file + # ) + # torch.save( + # sorted_rank_placement, + # os.path.join("processed", f"sorted_rank_placement_{num_partitions}.pt"), + # ) + + # torch.save( + # sorted_indices, os.path.join("processed", f"forward_map_{num_partitions}.pt") + # ) + # torch.save( + # reverse_maps, os.path.join("processed", f"reverse_map_{num_partitions}.pt") + # ) + sorted_rank_placement = None + sorted_indices = None + + reverse_maps = torch.load( + os.path.join("processed", f"reverse_map_{num_partitions}.pt") + ) + # Reorder the edge list in the graph file + coo_list = reorder_edge_list(graph_file, reverse_maps) + torch.save( + coo_list, + os.path.join("processed", f"edges_{num_partitions}.pt"), + ) + + +if __name__ == "__main__": + main() From 2ea4bcab8f91ed70b10b79a21433dc4ad204aba0 Mon Sep 17 00:00:00 2001 From: M Shehtab Zaman Date: Thu, 1 May 2025 15:06:55 -0700 Subject: [PATCH 2/8] Update local kernel --- experiments/Synthetic-Billion/README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/experiments/Synthetic-Billion/README.md b/experiments/Synthetic-Billion/README.md index 8a43815..61584ea 100644 --- a/experiments/Synthetic-Billion/README.md +++ b/experiments/Synthetic-Billion/README.md @@ -41,3 +41,4 @@ python preprocess.py 
--g --p --np The script will generate the necessary files for DGraph to run a distributed training partitioned in `` partitions. + From 82b48e9fc18595e206ce10fd1d9f97a88d2a835d Mon Sep 17 00:00:00 2001 From: M Shehtab Zaman Date: Thu, 29 May 2025 00:57:18 -0700 Subject: [PATCH 3/8] Optimized cache generation with scatter call and inverse_indices Add standalone file to generation cache to asynchronously generate and save caches - Update run code to load pre-saved cache files --- DGraph/distributed/RankLocalOps.py | 6 ++++++ experiments/OGB/GenerateCache.py | 3 ++- experiments/OGB/main.py | 15 +++++++-------- 3 files changed, 15 insertions(+), 9 deletions(-) diff --git a/DGraph/distributed/RankLocalOps.py b/DGraph/distributed/RankLocalOps.py index c4b6de0..8ed9f2d 100644 --- a/DGraph/distributed/RankLocalOps.py +++ b/DGraph/distributed/RankLocalOps.py @@ -140,7 +140,13 @@ def RankLocalRenumberingWithMapping(_indices, rank_mapping): unique_indices, inverse_indices = torch.unique(_indices, return_inverse=True) rank_mapping = rank_mapping.to(_indices.device) renumbered_indices = inverse_indices +<<<<<<< HEAD unique_rank_mapping = torch.zeros_like(unique_indices, dtype=rank_mapping.dtype, device=rank_mapping.device) +======= + unique_rank_mapping = torch.zeros_like( + unique_indices, dtype=rank_mapping.dtype, device=rank_mapping.device + ) +>>>>>>> a18faef (Optimized cache generation with scatter call and inverse_indices) unique_rank_mapping.scatter_(0, inverse_indices, rank_mapping) return renumbered_indices, unique_indices, unique_rank_mapping diff --git a/experiments/OGB/GenerateCache.py b/experiments/OGB/GenerateCache.py index a16e795..d35ca02 100644 --- a/experiments/OGB/GenerateCache.py +++ b/experiments/OGB/GenerateCache.py @@ -30,6 +30,7 @@ "ogbn-arxiv": "arxiv", "ogbn-products": "products", "ogbn-papers100M": "papers100M", + "ogbn-proteins": "proteins", } @@ -85,7 +86,7 @@ def generate_cache_file( def main(dset: str, world_size: int, node_rank_placement_file: str): - assert dset in ["ogbn-arxiv", "ogbn-products", "ogbn-papers100M"] + assert dset in ["ogbn-arxiv", "ogbn-products", "ogbn-papers100M", "ogbn-proteins"] assert world_size > 0 assert os.path.exists( diff --git a/experiments/OGB/main.py b/experiments/OGB/main.py index 3ccd58b..113159a 100644 --- a/experiments/OGB/main.py +++ b/experiments/OGB/main.py @@ -157,7 +157,7 @@ def _run_experiment( # This says where the edges are located edge_placement = rank_mappings[0] - + cache_prefix = f"cache/{dset_name}" scatter_cache_file = f"{cache_prefix}_scatter_cache_{world_size}_{rank}.pt" gather_cache_file = f"{cache_prefix}_gather_cache_{world_size}_{rank}.pt" @@ -187,8 +187,8 @@ def _run_experiment( world_size, ) with open(f"{log_prefix}_gather_cache_{world_size}_{rank}.pt", "wb") as f: - torch.save(gather_cache, f) - + torch.save(gather_cache, f) + if scatter_cache is None: nodes_per_rank = dataset.graph_obj.get_nodes_per_rank() @@ -230,12 +230,11 @@ def _run_experiment( end_time = perf_counter() print(f"Rank: {rank} Cache Generation Time: {end_time - start_time:.4f} s") - - #with open(f"{log_prefix}_gather_cache_{world_size}_{rank}.pt", "wb") as f: + # with open(f"{log_prefix}_gather_cache_{world_size}_{rank}.pt", "wb") as f: # torch.save(gather_cache, f) - #with open(f"{log_prefix}_scatter_cache_{world_size}_{rank}.pt", "wb") as f: + # with open(f"{log_prefix}_scatter_cache_{world_size}_{rank}.pt", "wb") as f: # torch.save(scatter_cache, f) - #print(f"Rank: {rank} Cache Generated") + # print(f"Rank: {rank} Cache Generated") 
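    # The epoch loop below times each step with CUDA events rather than
    # perf_counter, so the measured interval reflects GPU execution. A minimal
    # sketch of that pattern (assuming a CUDA device; names are illustrative only):
    #   start = torch.cuda.Event(enable_timing=True)
    #   end = torch.cuda.Event(enable_timing=True)
    #   start.record()
    #   loss = criterion(model(x), y); loss.backward(); optimizer.step()
    #   end.record()
    #   torch.cuda.synchronize()
    #   elapsed_ms = start.elapsed_time(end)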
training_times = [] for i in range(epochs): @@ -391,7 +390,7 @@ def main( use_cache=use_cache, num_classes=num_classes, dset_name=dset_name, - in_dim=in_dims[dset_name] + in_dim=in_dims[dset_name], ) training_trajectores[i] = training_traj validation_trajectores[i] = val_traj From cce5c0efa3979d36409be7797cdb68987f7f72fd Mon Sep 17 00:00:00 2001 From: M Shehtab Zaman Date: Wed, 23 Jul 2025 13:21:05 -0700 Subject: [PATCH 4/8] Modified graph cast and benchmark code to log performance benchmarks --- experiments/Benchmarks/TestNCCL.py | 160 ++++++++++-------- experiments/Benchmarks/generate_plots.py | 2 +- experiments/GraphCast/README.md | 5 + .../GraphCast/data_utils/graphcast_graph.py | 4 +- experiments/GraphCast/dataset.py | 1 + experiments/GraphCast/layers.py | 2 + experiments/GraphCast/model.py | 30 ++-- experiments/GraphCast/train_graphcast.py | 9 +- experiments/OGB/main.py | 6 + experiments/OGB/preprocess.py | 52 +++++- experiments/OGB/utils.py | 2 + 11 files changed, 181 insertions(+), 92 deletions(-) diff --git a/experiments/Benchmarks/TestNCCL.py b/experiments/Benchmarks/TestNCCL.py index 2a364af..e7848cd 100644 --- a/experiments/Benchmarks/TestNCCL.py +++ b/experiments/Benchmarks/TestNCCL.py @@ -169,7 +169,7 @@ def run_scatter_benchmark( def main(): parser = argparse.ArgumentParser() - parser.add_argument("--message_size", type=int, default=128) + parser.add_argument("--message_size", type=int, default=2) parser.add_argument("--benchmark_cache", action="store_true") parser.add_argument("--num_iters", type=int, default=1000) parser.add_argument("--log_dir", type=str, default="logs") @@ -196,92 +196,114 @@ def main(): benchmark.print(f"Running NCCL Benchmark on {world_size} ranks") # Built in small message benchmarks, in future we can add more - gather_graph_data = get_nccl_gather_benchmark_data(message_size, world_size, device) - benchmark.print("*" * 50) - benchmark.print("Running Gather Benchmark") - times = run_gather_benchmark(benchmark, num_iters, gather_graph_data, cache=None) - - benchmark.print("Saving Gather Benchmark Times") - - for i in range(world_size): - benchmark.save_np(times, f"{log_dir}/NCCL_gather_times_{i}.npy", rank_to_save=i) + for i in range(1, 20): + message_size *= 2 + benchmark.print("*" * 50) + benchmark.print(f"Running NCCL Benchmark for message size {message_size}") + gather_graph_data = get_nccl_gather_benchmark_data( + message_size, world_size, device + ) + dist.barrier() - benchmark.print("Gather Benchmark Complete") - benchmark.print("*" * 50) + benchmark.print("Running Gather Benchmark") + times = run_gather_benchmark( + benchmark, num_iters, gather_graph_data, cache=None + ) - if benchmark_cache: - edge_placement = gather_graph_data.edge_rank_placement - edge_src_rank = gather_graph_data.edge_src_rank - indices = gather_graph_data.edge_indices + benchmark.print("Saving Gather Benchmark Times") - gather_cache = NCCLGatherCacheGenerator( - indices, - edge_placement.view(-1), - edge_src_rank.view(-1), - 1, - rank, - world_size, + benchmark.save_np( + times, + f"{log_dir}/NCCL_gather_times_message_size_{message_size}" + + f"_world_size_{world_size}.npy", + rank_to_save=0, ) + + benchmark.print("Gather Benchmark Complete") benchmark.print("*" * 50) - benchmark.print("Running Gather Benchmark with Cache") - times = run_gather_benchmark( - benchmark, num_iters, gather_graph_data, cache=gather_cache - ) - benchmark.print("Saving Gather Benchmark with Cache Times") - for i in range(world_size): + if benchmark_cache: + edge_placement = 
gather_graph_data.edge_rank_placement + edge_src_rank = gather_graph_data.edge_src_rank + indices = gather_graph_data.edge_indices + + gather_cache = NCCLGatherCacheGenerator( + indices, + edge_placement.view(-1), + edge_src_rank.view(-1), + 1, + rank, + world_size, + ) + benchmark.print("*" * 50) + benchmark.print("Running Gather Benchmark with Cache") + times = run_gather_benchmark( + benchmark, num_iters, gather_graph_data, cache=gather_cache + ) + + benchmark.print("Saving Gather Benchmark with Cache Times") benchmark.save_np( - times, f"{log_dir}/NCCL_gather_with_cache_times_{i}.npy", rank_to_save=i + times, + f"{log_dir}/NCCL_gather_with_cache_message_size_{message_size}" + + f"_world_size_{world_size}.npy", + rank_to_save=0, ) - benchmark.print("Gather Benchmark with Cache Complete") - benchmark.print("*" * 50) + benchmark.print("Gather Benchmark with Cache Complete") + benchmark.print("*" * 50) - scatter_graph_data = get_nccl_scatter_benchmark_data( - message_size, world_size, device - ) - benchmark.print("*" * 50) - benchmark.print("Running Scatter Benchmark") - times = run_scatter_benchmark(benchmark, num_iters, scatter_graph_data, cache=None) + scatter_graph_data = get_nccl_scatter_benchmark_data( + message_size, world_size, device + ) - benchmark.print("Saving Scatter Benchmark Times") - for i in range(world_size): - benchmark.save_np( - times, f"{log_dir}/NCCL_scatter_times_{i}.npy", rank_to_save=i - ) + benchmark.print("*" * 50) + benchmark.print("Running Scatter Benchmark") + times = run_scatter_benchmark( + benchmark, num_iters, scatter_graph_data, cache=None + ) - benchmark.print("Scatter Benchmark Complete") - benchmark.print("*" * 50) - if benchmark_cache: - edge_placement = scatter_graph_data.edge_rank_placement - edge_dest_rank = scatter_graph_data.edge_dest_rank - indices = scatter_graph_data.edge_indices - - scatter_cache = NCCLScatterCacheGenerator( - indices, - edge_placement.view(-1), - edge_dest_rank.view(-1), - 1, - rank, - world_size, - ) - benchmark.print("*" * 50) - benchmark.print("Running Scatter Benchmark with Cache") - times = run_scatter_benchmark( - benchmark, num_iters, scatter_graph_data, cache=scatter_cache - ) + benchmark.print("Saving Scatter Benchmark Times") - benchmark.print("Saving Scatter Benchmark with Cache Times") - for i in range(world_size): benchmark.save_np( times, - f"{log_dir}/NCCL_scatter_with_cache_times_{i}.npy", - rank_to_save=i, + f"{log_dir}/NCCL_scatter_times_message_size_{message_size}" + + f"_world_size_{world_size}.npy", + rank_to_save=0, ) - benchmark.print("Scatter Benchmark with Cache Complete") - benchmark.print("*" * 50) + benchmark.print("Scatter Benchmark Complete") + benchmark.print("*" * 50) + if benchmark_cache: + edge_placement = scatter_graph_data.edge_rank_placement + edge_dest_rank = scatter_graph_data.edge_dest_rank + indices = scatter_graph_data.edge_indices + + scatter_cache = NCCLScatterCacheGenerator( + indices, + edge_placement.view(-1), + edge_dest_rank.view(-1), + 1, + rank, + world_size, + ) + benchmark.print("*" * 50) + benchmark.print("Running Scatter Benchmark with Cache") + times = run_scatter_benchmark( + benchmark, num_iters, scatter_graph_data, cache=scatter_cache + ) + + benchmark.print("Saving Scatter Benchmark with Cache Times") + + benchmark.save_np( + times, + f"{log_dir}/NCCL_scatter_with_cache_message_size_{message_size}" + + f"_world_size_{world_size}.npy", + rank_to_save=0, + ) + + benchmark.print("Scatter Benchmark with Cache Complete") + benchmark.print("*" * 50) 
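    # The sweep above starts from the --message_size argument (default 2) and
    # doubles it before each of the 19 iterations, so with the default the
    # benchmarked message sizes are 2**2 through 2**20. An equivalent sketch of
    # the swept sizes (illustrative only):
    #   sizes = [args.message_size << i for i in range(1, 20)]  # [4, 8, ..., 1048576]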
dist.destroy_process_group() diff --git a/experiments/Benchmarks/generate_plots.py b/experiments/Benchmarks/generate_plots.py index ececf05..f280c88 100644 --- a/experiments/Benchmarks/generate_plots.py +++ b/experiments/Benchmarks/generate_plots.py @@ -76,5 +76,5 @@ def generate_cache_comparison_plot(): if __name__ == "__main__": generate_plots("nccl") - generate_plots("nvshmem") + # generate_plots("nvshmem") generate_cache_comparison_plot() diff --git a/experiments/GraphCast/README.md b/experiments/GraphCast/README.md index bb68840..f98acfb 100644 --- a/experiments/GraphCast/README.md +++ b/experiments/GraphCast/README.md @@ -38,3 +38,8 @@ Run with benchmarking with the following command: python main.py --benchmark ``` ***Note: *** The graph requires a large amount of memory so better to do run on the CPU and a machine with a large amount of memory. + +Run with multiple processes per GPU with the following command: +```bash +torchrun-hpc --xargs=--mpibind=off --xargs=--gpu-bind=none train_graphcast.py --is_distributed True --procs_per_gpu 4 +``` \ No newline at end of file diff --git a/experiments/GraphCast/data_utils/graphcast_graph.py b/experiments/GraphCast/data_utils/graphcast_graph.py index 584fbce..f64acaf 100644 --- a/experiments/GraphCast/data_utils/graphcast_graph.py +++ b/experiments/GraphCast/data_utils/graphcast_graph.py @@ -219,7 +219,8 @@ def get_grid2mesh_graph(self, mesh_graph_dict: dict): contigous_edge_mapping, renumbered_edges = torch.sort(meshtogrid_edge_placement) src_grid_indices = src_grid_indices[renumbered_edges] - grid_vertex_rank_placement = torch.zeros_like(lat_lon_grid_flat) + grid_vertex_rank_placement = torch.zeros_like(lat_lon_grid_flat[:, 0]) + for i, rank in enumerate(meshtogrid_edge_placement): loc = src_grid_indices[i] grid_vertex_rank_placement[loc] = rank @@ -254,6 +255,7 @@ def get_mesh2grid_graph( ) edge_features, src_mesh_indices, dst_grid_indices = m2g_graph + breakpoint() src_mesh_indices = renumbered_vertices[src_mesh_indices] dst_grid_indices = renumbered_grid[dst_grid_indices] diff --git a/experiments/GraphCast/dataset.py b/experiments/GraphCast/dataset.py index 3e4eb08..5a65b3a 100644 --- a/experiments/GraphCast/dataset.py +++ b/experiments/GraphCast/dataset.py @@ -87,6 +87,7 @@ def __init__( self.lat_lon_grid = torch.stack( torch.meshgrid(self.latitudes, self.longitudes, indexing="ij"), dim=-1 ) + self.graph_cast_graph = DistributedGraphCastGraphGenerator( self.lat_lon_grid, mesh_level=self.mesh_level, diff --git a/experiments/GraphCast/layers.py b/experiments/GraphCast/layers.py index 524987b..1f08196 100644 --- a/experiments/GraphCast/layers.py +++ b/experiments/GraphCast/layers.py @@ -20,6 +20,8 @@ from DGraph.Communicator import Communicator from dist_utils import SingleProcessDummyCommunicator +# class MLPSiLuWithRecompute(nn.Module): + class MeshGraphMLP(nn.Module): """MLP for graph processing""" diff --git a/experiments/GraphCast/model.py b/experiments/GraphCast/model.py index da6459b..703f311 100644 --- a/experiments/GraphCast/model.py +++ b/experiments/GraphCast/model.py @@ -330,7 +330,10 @@ def __init__(self, cfg: Config, comm, *args, **kwargs): ) def forward( - self, input_grid_features: Tensor, static_graph: DistributedGraphCastGraph + self, + input_grid_features: Tensor, + static_graph: DistributedGraphCastGraph, + device: Optional[torch.device] = None, ) -> Tensor: """ Args: @@ -340,18 +343,19 @@ def forward( Returns: (Tensor): The predicted output grid """ - - input_grid_features = input_grid_features.squeeze(0) - 
input_mesh_features = static_graph.mesh_graph_node_features - mesh2mesh_edge_features = static_graph.mesh_graph_edge_features - grid2mesh_edge_features = static_graph.grid2mesh_graph_edge_features - mesh2grid_edge_features = static_graph.mesh2grid_graph_edge_features - mesh2mesh_edge_indices_src = static_graph.mesh_graph_src_indices - mesh2mesh_edge_indices_dst = static_graph.mesh_graph_dst_indices - mesh2grid_edge_indices_src = static_graph.mesh2grid_graph_src_indices - mesh2grid_edge_indices_dst = static_graph.mesh2grid_graph_dst_indices - grid2mesh_edge_indices_src = static_graph.grid2mesh_graph_src_indices - grid2mesh_edge_indices_dst = static_graph.grid2mesh_graph_dst_indices + if device is None: + device = input_grid_features.device + input_grid_features = input_grid_features.squeeze(0).to(device) + input_mesh_features = static_graph.mesh_graph_node_features.to(device) + mesh2mesh_edge_features = static_graph.mesh_graph_edge_features.to(device) + grid2mesh_edge_features = static_graph.grid2mesh_graph_edge_features.to(device) + mesh2grid_edge_features = static_graph.mesh2grid_graph_edge_features.to(device) + mesh2mesh_edge_indices_src = static_graph.mesh_graph_src_indices.to(device) + mesh2mesh_edge_indices_dst = static_graph.mesh_graph_dst_indices.to(device) + mesh2grid_edge_indices_src = static_graph.mesh2grid_graph_src_indices.to(device) + mesh2grid_edge_indices_dst = static_graph.mesh2grid_graph_dst_indices.to(device) + grid2mesh_edge_indices_src = static_graph.grid2mesh_graph_src_indices.to(device) + grid2mesh_edge_indices_dst = static_graph.grid2mesh_graph_dst_indices.to(device) out = self.embedder( input_grid_features, diff --git a/experiments/GraphCast/train_graphcast.py b/experiments/GraphCast/train_graphcast.py index 39fea87..be77b91 100644 --- a/experiments/GraphCast/train_graphcast.py +++ b/experiments/GraphCast/train_graphcast.py @@ -60,8 +60,10 @@ def main( comm = Communicator.init_process_group( _communicator, ranks_per_graph=procs_per_graph ) + mesh_graph_placement = torch.load("mesh_vertex_rank_placement_4.pt") else: comm = SingleProcessDummyCommunicator() + mesh_graph_placement = torch.zeros(40962, dtype=torch.int64) if not use_synthetic_data: raise NotImplementedError("Real data is not yet supported yet.") @@ -106,6 +108,7 @@ def main( dataset = SyntheticWeatherDataset( channels=[x for x in range(cfg.data.num_channels_climate)], num_samples_per_year=cfg.data.num_samples_per_year_train, + mesh_vertex_placement=mesh_graph_placement, num_steps=cfg.data.num_history, device=torch.device("cpu"), ) @@ -127,12 +130,12 @@ def main( break_training = False for data in dataloader: - in_data = data["invar"] - ground_truth = data["outvar"] + in_data = data["invar"].to(device) + ground_truth = data["outvar"].to(device) model.train() optimizer.zero_grad() - predicted_grid = model(in_data, static_graph) + predicted_grid = model(in_data, static_graph, device=device) loss = compute_loss(ground_truth, predicted_grid, comm) loss.backward() optimizer.step() diff --git a/experiments/OGB/main.py b/experiments/OGB/main.py index 113159a..d520396 100644 --- a/experiments/OGB/main.py +++ b/experiments/OGB/main.py @@ -163,6 +163,7 @@ def _run_experiment( gather_cache_file = f"{cache_prefix}_gather_cache_{world_size}_{rank}.pt" if os.path.exists(gather_cache_file): + print(f"Rank: {rank} Loading gather cache from {gather_cache_file}") gather_cache = torch.load(gather_cache_file, weights_only=False) if os.path.exists(scatter_cache_file): @@ -379,6 +380,11 @@ def main( validation_trajectores = 
np.zeros((runs, epochs)) validation_accuracies = np.zeros((runs, epochs)) world_size = comm.get_world_size() + + dist.barrier() + print(f"Running experiment with {world_size} processes on dataset {dataset}") + print(f"Using cache: {use_cache}") + for i in range(runs): log_prefix = f"{log_dir}/{dataset}_{world_size}_cache={use_cache}_run_{i}" training_traj, val_traj, val_accuracy = _run_experiment( diff --git a/experiments/OGB/preprocess.py b/experiments/OGB/preprocess.py index 5fc3cef..857bc76 100644 --- a/experiments/OGB/preprocess.py +++ b/experiments/OGB/preprocess.py @@ -1,4 +1,4 @@ -import metis +# import metis import torch import numpy as np @@ -10,6 +10,7 @@ import networkx as nx import numpy as np from fire import Fire +from tqdm import tqdm def partition_graph(coo_list: np.ndarray, num_ranks: int): @@ -105,13 +106,39 @@ def load_networkx_graph(dname): return G +def topological_sort_graph(coo_list, num_nodes): + in_degree = np.zeros(num_nodes, dtype=int) + adj_list = [[] for _ in range(num_nodes)] + for src, dst in tqdm(coo_list): + in_degree[dst] += 1 + adj_list[src].append(dst) + + queue = [node for node in range(num_nodes) if in_degree[node] == 0] + sorted_nodes = [] + while queue: + node = queue.pop(0) + sorted_nodes.append(node) + for neighbor in adj_list[node]: + in_degree[neighbor] -= 1 + if in_degree[neighbor] == 0: + queue.append(neighbor) + + if len(sorted_nodes) % 1000 == 0: + print(f"Processed {len(sorted_nodes)} nodes...") + + if len(sorted_nodes) != num_nodes: + raise ValueError("Graph is not a DAG, topological sort failed.") + + return np.array(sorted_nodes, dtype=np.int64) + + def main(dset_name: str): from ogb.nodeproppred import NodePropPredDataset - assert dset_name in ["ogbn-arxiv", "ogbn-products", "ogbn-proteins"] + assert dset_name in ["ogbn-arxiv", "ogbn-products", "ogbn-papers100M"] is_directed = False - if dset_name == "ogbn-arxiv": + if dset_name == "ogbn-arxiv" or dset_name == "ogbn-papers100M": is_directed = True dataset = NodePropPredDataset( @@ -119,12 +146,27 @@ def main(dset_name: str): ) graph_data, labels = dataset[0] - edge_index = torch.Tensor(graph_data["edge_index"]).long() + edge_index = torch.Tensor(graph_data["edge_index"]).long().clone() + num_nodes = graph_data["num_nodes"] + del graph_data + + print(f"Number of nodes: {num_nodes}") + + print(f"Number of edges: {edge_index.shape[1]}") + + # save_networkx_graph(coo_list, num_nodes, dset_name, directed=is_directed) + coo_list = edge_index.numpy().T + sorted_vertices = topological_sort_graph(coo_list, num_nodes) + + np.save(f"{dset_name}_sorted_vertices.npy", sorted_vertices) - save_networkx_graph(coo_list, num_nodes, dset_name, directed=is_directed) + # print(f"Number of edges in COO format: {coo_list.shape}") + # with open(f"{dset_name}_coo_list.csv", "w") as f: + # for edge in tqdm(coo_list): + # f.write(f"{edge[0]},{edge[1]}\n") if __name__ == "__main__": diff --git a/experiments/OGB/utils.py b/experiments/OGB/utils.py index fef5d8f..30fe867 100644 --- a/experiments/OGB/utils.py +++ b/experiments/OGB/utils.py @@ -54,6 +54,8 @@ def safe_create_dir(directory, rank): def calculate_accuracy(pred, labels): + if len(labels) == 0: + return 0.0 pred = pred.argmax(dim=1) correct = pred.eq(labels).sum().item() if len(labels) > 0: From d4c7306217ac4d9658aeed6e1ceb9cf01e8e2b18 Mon Sep 17 00:00:00 2001 From: M Shehtab Zaman Date: Sat, 26 Jul 2025 10:12:15 -0700 Subject: [PATCH 5/8] Add barrier around dataset processor for race condition --- DGraph/distributed/nccl/NCCLBackendEngine.py | 5 +++++ 
experiments/OGB/GCN.py | 18 ++++++++++++++++++ experiments/OGB/main.py | 15 ++++++++++++--- 3 files changed, 35 insertions(+), 3 deletions(-) diff --git a/DGraph/distributed/nccl/NCCLBackendEngine.py b/DGraph/distributed/nccl/NCCLBackendEngine.py index b3ea11a..ddc3dc2 100644 --- a/DGraph/distributed/nccl/NCCLBackendEngine.py +++ b/DGraph/distributed/nccl/NCCLBackendEngine.py @@ -512,6 +512,11 @@ def __init__(self, ranks_per_graph=-1, *args, **kwargs): if not NCCLBackendEngine._is_initialized: self.init_process_group(ranks_per_graph) + def barrier(self) -> None: + if not dist.is_initialized(): + raise RuntimeError("NCCL backend engine is not initialized") + dist.barrier() + def init_process_group(self, ranks_per_graph=-1, *args, **kwargs): if not dist.is_initialized(): dist.init_process_group(backend="nccl", *args, **kwargs) diff --git a/experiments/OGB/GCN.py b/experiments/OGB/GCN.py index 3de4c78..0927ca1 100644 --- a/experiments/OGB/GCN.py +++ b/experiments/OGB/GCN.py @@ -14,6 +14,7 @@ import torch import torch.nn as nn import torch.distributed as dist +from DGraph.utils.TimingReport import TimingReport class ConvLayer(nn.Module): @@ -54,24 +55,41 @@ def forward( num_local_nodes = node_features.size(1) _src_indices = edge_index[:, 0, :] _dst_indices = edge_index[:, 1, :] + TimingReport.start("pre-processing") _src_rank_mappings = torch.cat( [rank_mapping[0].unsqueeze(0), rank_mapping[0].unsqueeze(0)], dim=0 ) _dst_rank_mappings = torch.cat( [rank_mapping[0].unsqueeze(0), rank_mapping[1].unsqueeze(0)], dim=0 ) + TimingReport.stop("pre-processing") + TimingReport.start("Gather_1") x = self.comm.gather( node_features, _dst_indices, _dst_rank_mappings, cache=gather_cache ) + TimingReport.stop("Gather_1") + TimingReport.start("Conv_1") x = self.conv1(x) + TimingReport.stop("Conv_1") + TimingReport.start("Scatter_1") x = self.comm.scatter( x, _src_indices, _src_rank_mappings, num_local_nodes, cache=scatter_cache ) + TimingReport.stop("Scatter_1") + TimingReport.start("Gather_2") x = self.comm.gather(x, _dst_indices, _dst_rank_mappings, cache=gather_cache) + TimingReport.stop("Gather_2") + TimingReport.start("Conv_2") x = self.conv2(x) + TimingReport.stop("Conv_2") + TimingReport.start("Scatter_2") x = self.comm.scatter( x, _src_indices, _src_rank_mappings, num_local_nodes, cache=scatter_cache ) + TimingReport.stop("Scatter_2") + TimingReport.start("Final_FC") x = self.fc(x) + TimingReport.stop("Final_FC") + # x = self.softmax(x) return x diff --git a/experiments/OGB/main.py b/experiments/OGB/main.py index d520396..47e356b 100644 --- a/experiments/OGB/main.py +++ b/experiments/OGB/main.py @@ -38,6 +38,8 @@ ) import numpy as np import os +from DGraph.utils.TimingReport import TimingReport +import json class SingleProcessDummyCommunicator(CommunicatorBase): @@ -131,7 +133,6 @@ def _run_experiment( print(f"Rank: {rank} Mapping: {rank_mappings.shape}") print(f"Rank: {rank} Node Features: {node_features.shape}") print(f"Rank: {rank} Edge Indices: {edge_indices.shape}") - comm.barrier() criterion = torch.nn.CrossEntropyLoss() @@ -229,7 +230,9 @@ def _run_experiment( assert rank != rank assert value.shape[0] == scatter_cache.gather_recv_comm_vector end_time = perf_counter() - print(f"Rank: {rank} Cache Generation Time: {end_time - start_time:.4f} s") + elapsed_time_in_ms = (end_time - start_time) * 1000 + print(f"Rank: {rank} Cache Generation Time: {elapsed_time_in_ms:.4f} ms") + TimingReport.add_time("cache_generation_time", elapsed_time_in_ms) # with 
open(f"{log_prefix}_gather_cache_{world_size}_{rank}.pt", "wb") as f: # torch.save(gather_cache, f) @@ -366,6 +369,7 @@ def main( node_rank_placement_file, weights_only=False ) + TimingReport.init(comm) safe_create_dir(log_dir, comm.get_rank()) training_dataset = DistributedOGBWrapper( f"ogbn-{dataset}", @@ -381,7 +385,7 @@ def main( validation_accuracies = np.zeros((runs, epochs)) world_size = comm.get_world_size() - dist.barrier() + comm.barrier() print(f"Running experiment with {world_size} processes on dataset {dataset}") print(f"Using cache: {use_cache}") @@ -402,6 +406,11 @@ def main( validation_trajectores[i] = val_traj validation_accuracies[i] = val_accuracy + write_experiment_log( + json.dumps(TimingReport._timers), + f"{log_dir}/timing_report_world_size_{world_size}_cache_{use_cache}.json", + comm.get_rank(), + ) visualize_trajectories( training_trajectores, "Training Loss", From 97b5e8d5c787aae43b1ac5b3230c3a001ffa7a54 Mon Sep 17 00:00:00 2001 From: M Shehtab Zaman Date: Mon, 25 Aug 2025 22:46:53 -0700 Subject: [PATCH 6/8] Rebase with local kernels --- DGraph/distributed/RankLocalOps.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/DGraph/distributed/RankLocalOps.py b/DGraph/distributed/RankLocalOps.py index 8ed9f2d..86216f6 100644 --- a/DGraph/distributed/RankLocalOps.py +++ b/DGraph/distributed/RankLocalOps.py @@ -140,13 +140,9 @@ def RankLocalRenumberingWithMapping(_indices, rank_mapping): unique_indices, inverse_indices = torch.unique(_indices, return_inverse=True) rank_mapping = rank_mapping.to(_indices.device) renumbered_indices = inverse_indices -<<<<<<< HEAD - unique_rank_mapping = torch.zeros_like(unique_indices, dtype=rank_mapping.dtype, device=rank_mapping.device) -======= unique_rank_mapping = torch.zeros_like( unique_indices, dtype=rank_mapping.dtype, device=rank_mapping.device ) ->>>>>>> a18faef (Optimized cache generation with scatter call and inverse_indices) unique_rank_mapping.scatter_(0, inverse_indices, rank_mapping) return renumbered_indices, unique_indices, unique_rank_mapping From c14001e12b13a1b5c95f448270c39842e6b177bc Mon Sep 17 00:00:00 2001 From: M Shehtab Zaman Date: Thu, 28 Aug 2025 04:12:34 -0700 Subject: [PATCH 7/8] Add IGB pre-processing and running code --- experiments/Benchmarks/TestNVSHMEM.py | 48 ++-- experiments/IGB/GCN.py | 95 +++++++ experiments/IGB/GenerateCache.py | 158 +++++++++++ experiments/IGB/IGB260MDataset.py | 115 ++++++++ experiments/IGB/main.py | 379 ++++++++++++++++++++++++++ experiments/IGB/utils.py | 64 +++++ experiments/OGB/main.py | 8 +- 7 files changed, 843 insertions(+), 24 deletions(-) create mode 100644 experiments/IGB/GCN.py create mode 100644 experiments/IGB/GenerateCache.py create mode 100644 experiments/IGB/IGB260MDataset.py create mode 100644 experiments/IGB/main.py create mode 100644 experiments/IGB/utils.py diff --git a/experiments/Benchmarks/TestNVSHMEM.py b/experiments/Benchmarks/TestNVSHMEM.py index 7e6989f..9427f53 100644 --- a/experiments/Benchmarks/TestNVSHMEM.py +++ b/experiments/Benchmarks/TestNVSHMEM.py @@ -127,7 +127,7 @@ def run_scatter_benchmark( def main(): parser = argparse.ArgumentParser() - parser.add_argument("--message_size", type=int, default=128) + parser.add_argument("--message_size", type=int, default=2) parser.add_argument("--benchmark_cache", action="store_true") parser.add_argument("--num_iters", type=int, default=1000) parser.add_argument("--log_dir", type=str, default="logs") @@ -160,37 +160,45 @@ def main(): benchmark.print("*" * 50) benchmark.print("Running Gather Benchmark") - 
gather_graph_data = get_nvshmem_gather_benchmark_data( - message_size, rank, world_size, device - ) - times = run_gather_benchmark(benchmark, num_iters, gather_graph_data) + for i in range(1, 20): + message_size *= 2 + benchmark.print(f"Running NCCL Benchmark for message size {message_size}") + + gather_graph_data = get_nvshmem_gather_benchmark_data( + message_size, rank, world_size, device + ) + times = run_gather_benchmark(benchmark, num_iters, gather_graph_data) - benchmark.print("Saving Gather Benchmark Times") + benchmark.print("Saving Gather Benchmark Times") - for i in range(world_size): benchmark.save_np( - times, f"{log_dir}/NVSHMEM_gather_times_{i}.npy", rank_to_save=i + times, + f"{log_dir}/NVSHMEM_gather_times_message_size_{message_size}" + + f"_with_world_size_{world_size}.npy", + rank_to_save=0, ) - benchmark.print("Gather Benchmark Complete") - benchmark.print("*" * 50) + benchmark.print("Gather Benchmark Complete") + benchmark.print("*" * 50) - scatter_graph_data = get_nvshmem_scatter_benchmark_data( - message_size, rank, world_size, device - ) + scatter_graph_data = get_nvshmem_scatter_benchmark_data( + message_size, rank, world_size, device + ) - benchmark.print("Running Scatter Benchmark") - times = run_scatter_benchmark(benchmark, num_iters, scatter_graph_data) + benchmark.print("Running Scatter Benchmark") + times = run_scatter_benchmark(benchmark, num_iters, scatter_graph_data) - benchmark.print("Saving Scatter Benchmark Times") + benchmark.print("Saving Scatter Benchmark Times") - for i in range(world_size): benchmark.save_np( - times, f"{log_dir}/NVSHMEM_scatter_times_{i}.npy", rank_to_save=i + times, + f"{log_dir}/NVSHMEM_scatter_times_message_size_{message_size}" + + f"_with_world_size_{world_size}.npy", + rank_to_save=0, ) - benchmark.print("Scatter Benchmark Complete") - benchmark.print("*" * 50) + benchmark.print("Scatter Benchmark Complete") + benchmark.print("*" * 50) if __name__ == "__main__": diff --git a/experiments/IGB/GCN.py b/experiments/IGB/GCN.py new file mode 100644 index 0000000..b61af13 --- /dev/null +++ b/experiments/IGB/GCN.py @@ -0,0 +1,95 @@ +# Copyright (c) 2014-2025, Lawrence Livermore National Security, LLC. +# Produced at the Lawrence Livermore National Laboratory. +# Written by the LBANN Research Team (B. Van Essen, et al.) listed in +# the CONTRIBUTORS file. See the top-level LICENSE file for details. +# +# LLNL-CODE-697807. +# All rights reserved. +# +# This file is part of LBANN: Livermore Big Artificial Neural Network +# Toolkit. For details, see http://software.llnl.gov/LBANN or +# https://github.com/LBANN and https://github.com/LLNL/LBANN. +# +# SPDX-License-Identifier: (Apache-2.0) +import torch +import torch.nn as nn +import torch.distributed as dist +from DGraph.utils.TimingReport import TimingReport + + +class ConvLayer(nn.Module): + def __init__(self, in_channels, out_channels): + super(ConvLayer, self).__init__() + self.conv = nn.Linear(in_channels, out_channels) + self.act = nn.ReLU(inplace=True) + + def forward(self, x): + x = self.conv(x) + x = self.act(x) + return x + + +class CommAwareGCN(nn.Module): + """ + Least interesting GNN model to test distributed training + but good enough for the purpose of testing. 
+ """ + + def __init__(self, in_channels, hidden_dims, num_classes, comm): + super(CommAwareGCN, self).__init__() + + self.conv1 = ConvLayer(in_channels, hidden_dims) + self.conv2 = ConvLayer(hidden_dims, hidden_dims) + self.fc = nn.Linear(hidden_dims, num_classes) + self.softmax = nn.Softmax(dim=1) + self.comm = comm + + def forward( + self, + node_features, + edge_index, + rank_mapping, + gather_cache=None, + scatter_cache=None, + ): + num_local_nodes = node_features.size(1) + _src_indices = edge_index[:, 0, :] + _dst_indices = edge_index[:, 1, :] + TimingReport.start("pre-processing") + _src_rank_mappings = torch.cat( + [rank_mapping[0].unsqueeze(0), rank_mapping[0].unsqueeze(0)], dim=0 + ) + _dst_rank_mappings = torch.cat( + [rank_mapping[0].unsqueeze(0), rank_mapping[1].unsqueeze(0)], dim=0 + ) + TimingReport.stop("pre-processing") + TimingReport.start("Gather_1") + x = self.comm.gather( + node_features, _dst_indices, _dst_rank_mappings, cache=gather_cache + ) + TimingReport.stop("Gather_1") + TimingReport.start("Conv_1") + x = self.conv1(x) + TimingReport.stop("Conv_1") + TimingReport.start("Scatter_1") + x = self.comm.scatter( + x, _src_indices, _src_rank_mappings, num_local_nodes, cache=scatter_cache + ) + TimingReport.stop("Scatter_1") + TimingReport.start("Gather_2") + x = self.comm.gather(x, _dst_indices, _dst_rank_mappings, cache=gather_cache) + TimingReport.stop("Gather_2") + TimingReport.start("Conv_2") + x = self.conv2(x) + TimingReport.stop("Conv_2") + TimingReport.start("Scatter_2") + x = self.comm.scatter( + x, _src_indices, _src_rank_mappings, num_local_nodes, cache=scatter_cache + ) + TimingReport.stop("Scatter_2") + TimingReport.start("Final_FC") + x = self.fc(x) + TimingReport.stop("Final_FC") + + # x = self.softmax(x) + return x diff --git a/experiments/IGB/GenerateCache.py b/experiments/IGB/GenerateCache.py new file mode 100644 index 0000000..562fedf --- /dev/null +++ b/experiments/IGB/GenerateCache.py @@ -0,0 +1,158 @@ +# Copyright (c) 2014-2024, Lawrence Livermore National Security, LLC. +# Produced at the Lawrence Livermore National Laboratory. +# Written by the LBANN Research Team (B. Van Essen, et al.) listed in +# the CONTRIBUTORS file. See the top-level LICENSE file for details. +# +# LLNL-CODE-697807. +# All rights reserved. +# +# This file is part of LBANN: Livermore Big Artificial Neural Network +# Toolkit. For details, see http://software.llnl.gov/LBANN or +# https://github.com/LBANN and https://github.com/LLNL/LBANN. 
+# +# SPDX-License-Identifier: (Apache-2.0) + +from DGraph.data.ogbn_datasets import process_homogenous_data +from IGB260MDataset import DistributedIGBWrapper +from fire import Fire +import os +import torch +from DGraph.distributed.nccl._nccl_cache import ( + NCCLGatherCacheGenerator, + NCCLScatterCacheGenerator, +) +from time import perf_counter +from tqdm import tqdm +from multiprocessing import get_context + + +cache_prefix = { + "ogbn-arxiv": "arxiv", + "ogbn-products": "products", + "ogbn-papers100M": "papers100M", + "ogbn-proteins": "proteins", +} + + +def generate_cache_file( + dist_graph, + src_indices, + dst_indices, + edge_placement, + edge_src_placement, + edge_dest_placement, + cache_prefix_str: str, + rank: int, + world_size: int, +): + print(f"Generating cache for rank {rank}...") + local_node_features = dist_graph.get_local_node_features(rank).unsqueeze(0) + num_input_rows = local_node_features.size(1) + + print( + f"Rank {rank} has {num_input_rows} input rows with shape {local_node_features.shape}" + ) + gather_cache = NCCLGatherCacheGenerator( + dst_indices, + edge_placement, + edge_dest_placement, + num_input_rows, + rank, + world_size, + ) + + nodes_per_rank = dist_graph.get_nodes_per_rank() + nodes_per_rank = int(nodes_per_rank[rank].item()) + + scatter_cache = NCCLScatterCacheGenerator( + src_indices, + edge_placement, + edge_src_placement, + nodes_per_rank, + rank, + world_size, + ) + print(f"Rank {rank} completed cache generation") + with open( + f"{cache_prefix_str}_gather_cache_rank_{world_size}_{rank}.pt", "wb" + ) as f: + torch.save(gather_cache, f) + + with open( + f"{cache_prefix_str}_scatter_cache_rank_{world_size}_{rank}.pt", "wb" + ) as f: + torch.save(scatter_cache, f) + return 0 + + +class DummyComm: + def __init__(self, world_size: int): + self.world_size = world_size + self.rank = 0 + + +def main(root, world_size: int, node_rank_placement_file=None): + assert world_size > 0 + + dataset = DistributedIGBWrapper( + root=root, + comm=DummyComm(world_size), + node_rank_placement=node_rank_placement_file, + sim_node_features=True, + ) + + num_edges = dataset.num_edges + print(num_edges) + + dist_graph = dataset.graph_obj + + edge_indices = dist_graph.get_global_edge_indices() + rank_mappings = dist_graph.get_global_rank_mappings() + + print("Edge indices shape:", edge_indices.shape) + print("Rank mappings shape:", rank_mappings.shape) + + edge_indices = edge_indices.unsqueeze(0) + src_indices = edge_indices[:, 0, :] + dst_indices = edge_indices[:, 1, :] + + edge_placement = rank_mappings[0] + edge_src_placement = rank_mappings[0] + edge_dest_placement = rank_mappings[1] + + start_time = perf_counter() + cache_prefix_str = f"cache/IGB" + os.makedirs("cache", exist_ok=True) + os.makedirs("cache/IGB", exist_ok=True) + + with get_context("spawn").Pool(min(world_size, 8)) as pool: + args = [ + ( + dist_graph, + src_indices, + dst_indices, + edge_placement, + edge_src_placement, + edge_dest_placement, + cache_prefix_str, + rank, + world_size, + ) + for rank in range(world_size) + ] + + out = pool.starmap(generate_cache_file, args) + + end_time = perf_counter() + print(f"Cache generation time: {end_time - start_time:.4f} seconds") + print("Cache files generated successfully.") + print( + f"Gather cache file: {cache_prefix_str}_gather_cache_rank_{world_size}_.pt" + ) + print( + f"Scatter cache file: {cache_prefix_str}_scatter_cache_rank_{world_size}_.pt" + ) + + +if __name__ == "__main__": + Fire(main) diff --git a/experiments/IGB/IGB260MDataset.py 
b/experiments/IGB/IGB260MDataset.py new file mode 100644 index 0000000..a8c4d6f --- /dev/null +++ b/experiments/IGB/IGB260MDataset.py @@ -0,0 +1,115 @@ +import torch +import numpy as np +import os.path as osp +from DGraph.CommunicatorBase import CommunicatorBase +from DGraph.data.ogbn_datasets import process_homogenous_data + + +def assign_node_rank(num_nodes, world_size): + _div = num_nodes // world_size + _mod = num_nodes % world_size + arr = np.arange(world_size).repeat(_div) + if _mod > 0: + arr = np.concatenate((arr, np.arange(_mod))) + np.random.shuffle(arr) + return torch.from_numpy(arr).long() + + +class DistributedIGBWrapper: + def __init__( + self, + root, + comm, + node_rank_placement=None, + sim_node_features=True, + num_features=1, + ): + self.root = root + self.comm_object = comm + self.rank = comm.get_rank() + self.world_size = comm.get_world_size() + self.num_features = num_features + self.num_nodes = 227130858 + self.num_edges = 3727095830 + self.num_classes = 19 + self.sim_node_features = sim_node_features + if node_rank_placement is None: + node_rank_placement = assign_node_rank(self.num_nodes, self.world_size) + self.load_graph_data(node_rank_placement) + + def load_graph_data(self, node_rank_placement): + processed_dir = osp.join(self.root, "processed") + edge_dir = osp.join(processed_dir, "paper__cites__paper") + node_features_dir = osp.join(processed_dir, "paper") + edges = np.load(osp.join(edge_dir, "edge_index.npy"), mmap_mode="r") + + graph_data = {"edge_index": edges} + + if self.sim_node_features: + node_features = torch.randn( + (self.num_nodes, self.num_features), dtype=torch.float32 + ) + else: + node_features = np.memmap( + osp.join(node_features_dir, "node_feat.npy"), + mode="r", + dtype="float32", + shape=(self.num_nodes, 1024), + ) + self.num_features = 1024 + + graph_data["node_feat"] = node_features + labels = np.memmap( + osp.join(node_features_dir, "node_label_19.npy"), + mode="r", + dtype="float32", + ) + + n_train = int(self.num_nodes * 0.6) + n_val = int(self.num_nodes * 0.2) + + train_mask = torch.zeros(self.num_nodes, dtype=torch.bool) + val_mask = torch.zeros(self.num_nodes, dtype=torch.bool) + test_mask = torch.zeros(self.num_nodes, dtype=torch.bool) + + train_mask[:n_train] = True + val_mask[n_train : n_train + n_val] = True + test_mask[n_train + n_val : self.num_nodes] = True + + split_idx = { + "train": train_mask, + "valid": val_mask, + "test": test_mask, + } + graph_obj = process_homogenous_data( + graph_data, + labels, + 0, + self.world_size, + split_idx, + node_rank_placement, + ) + self.graph_obj = graph_obj + + def __len__(self) -> int: + return 1 + + def __getitem__(self, idx: int): + rank = self.comm_object.get_rank() + local_node_features = self.graph_obj.get_local_node_features(rank=rank) + labels = self.graph_obj.get_local_labels(rank=rank) + + # TODO: Move this to a backend-specific collator in the future + if self.comm_object.backend == "nccl": + # Return Graph object with Rank placement data + + # NOTE: Two-sided comm needs all the edge indices not the local ones + edge_indices = self.graph_obj.get_global_edge_indices() + rank_mappings = self.graph_obj.get_global_rank_mappings() + else: + # One-sided communication, no need for rank placement data + + edge_indices = self.graph_obj.get_local_edge_indices(rank=rank) + rank_mappings = self.graph_obj.get_local_rank_mappings(rank=rank) + + return local_node_features, edge_indices, rank_mappings, labels diff --git a/experiments/IGB/main.py b/experiments/IGB/main.py new file mode 
100644 index 0000000..b1570cd --- /dev/null +++ b/experiments/IGB/main.py @@ -0,0 +1,379 @@ +# Copyright (c) 2014-2024, Lawrence Livermore National Security, LLC. +# Produced at the Lawrence Livermore National Laboratory. +# Written by the LBANN Research Team (B. Van Essen, et al.) listed in +# the CONTRIBUTORS file. See the top-level LICENSE file for details. +# +# LLNL-CODE-697807. +# All rights reserved. +# +# This file is part of LBANN: Livermore Big Artificial Neural Network +# Toolkit. For details, see http://software.llnl.gov/LBANN or +# https://github.com/LBANN and https://github.com/LLNL/LBANN. +# +# SPDX-License-Identifier: (Apache-2.0) +import sys +from time import perf_counter +from typing import Optional +from IGB260MDataset import DistributedIGBWrapper +from DGraph.Communicator import CommunicatorBase, Communicator + +from DGraph.distributed.nccl._nccl_cache import ( + NCCLGatherCacheGenerator, + NCCLScatterCacheGenerator, +) +import fire +import torch +import torch.optim as optim +import torch.distributed as dist +from torch.nn.parallel import DistributedDataParallel as DDP +from GCN import CommAwareGCN as GCN +from utils import ( + dist_print_ephemeral, + make_experiment_log, + write_experiment_log, + cleanup, + visualize_trajectories, + safe_create_dir, + calculate_accuracy, +) +import numpy as np +import os +from DGraph.utils.TimingReport import TimingReport +import json + + +def _run_experiment( + dataset, + comm, + lr: float, + epochs: int, + log_prefix: str, + in_dim: int = 128, + hidden_dims: int = 128, + num_classes: int = 40, + use_cache: bool = False, + dset_name: str = "arxiv", +): + local_rank = comm.get_rank() % torch.cuda.device_count() + print(f"Rank: {local_rank} Local Rank: {local_rank}") + torch.cuda.set_device(local_rank) + device = torch.cuda.current_device() + model = GCN( + in_channels=in_dim, hidden_dims=hidden_dims, num_classes=num_classes, comm=comm + ) + rank = comm.get_rank() + model = model.to(device) + + model = ( + DDP(model, device_ids=[local_rank], output_device=local_rank) + if comm.get_world_size() > 1 + else model + ) + optimizer = optim.Adam(model.parameters(), lr=lr) + + stream = torch.cuda.Stream() + + node_features, edge_indices, rank_mappings, labels = dataset[0] + + node_features = node_features.to(device).unsqueeze(0) + edge_indices = edge_indices.to(device).unsqueeze(0) + labels = labels.to(device).unsqueeze(0) + rank_mappings = rank_mappings + + if rank == 0: + print("*" * 80) + for i in range(comm.get_world_size()): + if i == rank: + print(f"Rank: {rank} Mapping: {rank_mappings.shape}") + print(f"Rank: {rank} Node Features: {node_features.shape}") + print(f"Rank: {rank} Edge Indices: {edge_indices.shape}") + comm.barrier() + criterion = torch.nn.CrossEntropyLoss() + + train_mask = dataset.graph_obj.get_local_mask("train", rank) + validation_mask = dataset.graph_obj.get_local_mask("val", rank) + training_loss_scores = [] + validation_loss_scores = [] + validation_accuracy_scores = [] + + world_size = comm.get_world_size() + + print(f"Rank: {rank} training_mask: {train_mask.shape}") + print(f"Rank: {rank} validation_mask: {validation_mask.shape}") + + gather_cache = None + scatter_cache = None + + if use_cache: + print(f"Rank: {rank} Using Cache. 
Generating Cache") + start_time = perf_counter() + src_indices = edge_indices[:, 0, :] + dst_indices = edge_indices[:, 1, :] + + # This says where the edges are located + edge_placement = rank_mappings[0] + + cache_prefix = f"cache/{dset_name}" + scatter_cache_file = f"{cache_prefix}_scatter_cache_{world_size}_{rank}.pt" + gather_cache_file = f"{cache_prefix}_gather_cache_{world_size}_{rank}.pt" + + if os.path.exists(gather_cache_file): + print(f"Rank: {rank} Loading gather cache from {gather_cache_file}") + gather_cache = torch.load(gather_cache_file, weights_only=False) + + if os.path.exists(scatter_cache_file): + scatter_cache = torch.load(scatter_cache_file, weight_only=False) + + # These say where the source and destination nodes are located + edge_src_placement = rank_mappings[ + 0 + ] # Redundant but making explicit for clarity + edge_dest_placement = rank_mappings[1] + + num_input_rows = node_features.size(1) + local_num_edges = (edge_placement == rank).sum().item() + + if gather_cache is None: + gather_cache = NCCLGatherCacheGenerator( + dst_indices, + edge_placement, + edge_dest_placement, + num_input_rows, + rank, + world_size, + ) + with open(f"{log_prefix}_gather_cache_{world_size}_{rank}.pt", "wb") as f: + torch.save(gather_cache, f) + + if scatter_cache is None: + nodes_per_rank = dataset.graph_obj.get_nodes_per_rank() + + scatter_cache = NCCLScatterCacheGenerator( + src_indices, + edge_placement, + edge_src_placement, + nodes_per_rank[rank], + rank, + world_size, + ) + with open(f"{log_prefix}_scatter_cache_{world_size}_{rank}.pt", "wb") as f: + torch.save(scatter_cache, f) + + # Sanity checks for the cache + for key, value in gather_cache.gather_send_local_placement.items(): + assert value.max().item() < num_input_rows + assert key < world_size + assert key != rank + assert value.shape[0] == gather_cache.gather_send_comm_vector[key] + + for key, value in gather_cache.gather_recv_local_placement.items(): + assert value.max().item() < local_num_edges + assert key < world_size + assert key != rank + assert value.shape[0] == gather_cache.gather_recv_comm_vector[key] + + for rank, value in scatter_cache.gather_send_local_placement.items(): + assert value.max().item() < local_num_edges + assert rank < world_size + assert rank != rank + assert value.shape[0] == scatter_cache.gather_send_comm_vector + + for rank, value in scatter_cache.gather_recv_local_placement.items(): + assert value.max().item() < num_input_rows + assert rank < world_size + assert rank != rank + assert value.shape[0] == scatter_cache.gather_recv_comm_vector + end_time = perf_counter() + elapsed_time_in_ms = (end_time - start_time) * 1000 + print(f"Rank: {rank} Cache Generation Time: {elapsed_time_in_ms:.4f} ms") + TimingReport.add_time("cache_generation_time", elapsed_time_in_ms) + + # with open(f"{log_prefix}_gather_cache_{world_size}_{rank}.pt", "wb") as f: + # torch.save(gather_cache, f) + # with open(f"{log_prefix}_scatter_cache_{world_size}_{rank}.pt", "wb") as f: + # torch.save(scatter_cache, f) + # print(f"Rank: {rank} Cache Generated") + + training_times = [] + for i in range(epochs): + comm.barrier() + torch.cuda.synchronize() + start_time = torch.cuda.Event(enable_timing=True) + end_time = torch.cuda.Event(enable_timing=True) + start_time.record(stream) + optimizer.zero_grad() + _output = model( + node_features, edge_indices, rank_mappings, gather_cache, scatter_cache + ) + # Must flatten along the batch dimension for the loss function + output = _output[:, train_mask].view(-1, num_classes) + gt = 
labels[:, train_mask].view(-1) + loss = criterion(output, gt) + loss.backward() + dist_print_ephemeral(f"Epoch {i} \t Loss: {loss.item()}", rank) + optimizer.step() + + comm.barrier() + end_time.record(stream) + torch.cuda.synchronize() + training_times.append(start_time.elapsed_time(end_time)) + training_loss_scores.append(loss.item()) + write_experiment_log(str(loss.item()), f"{log_prefix}_training_loss.log", rank) + + model.eval() + with torch.no_grad(): + validation_preds = _output[:, validation_mask].view(-1, num_classes) + label_validation = labels[:, validation_mask].view(-1) + validation_score = criterion( + validation_preds, + label_validation, + ) + write_experiment_log( + str(validation_score.item()), f"{log_prefix}_validation_loss.log", rank + ) + + validation_loss_scores.append(validation_score.item()) + + val_pred = torch.log_softmax(validation_preds, dim=1) + accuracy = calculate_accuracy(val_pred, label_validation) + validation_accuracy_scores.append(accuracy) + write_experiment_log( + f"Validation Accuracy: {accuracy:.2f}", + f"{log_prefix}_validation_accuracy.log", + rank, + ) + model.train() + + torch.cuda.synchronize() + + model.eval() + + with torch.no_grad(): + test_idx = dataset.graph_obj.get_local_mask("test", rank) + test_labels = labels[:, test_idx].view(-1) + test_preds = model(node_features, edge_indices, rank_mappings)[:, test_idx] + test_preds = test_preds.view(-1, num_classes) + test_loss = criterion(test_preds, test_labels) + test_preds = torch.log_softmax(test_preds, dim=1) + test_accuracy = calculate_accuracy(test_preds, test_labels) + test_log_file = f"{log_prefix}_test_results.log" + write_experiment_log( + "loss,accuracy", + test_log_file, + rank, + ) + write_experiment_log(f"{test_loss.item()},{test_accuracy}", test_log_file, rank) + + make_experiment_log(f"{log_prefix}_training_times.log", rank) + make_experiment_log(f"{log_prefix}_runtime_experiment.log", rank) + + for times in training_times: + write_experiment_log(str(times), f"{log_prefix}_training_times.log", rank) + + average_time = np.mean(training_times[1:]) + log_str = f"Average time per epoch: {average_time:.4f} ms" + write_experiment_log(log_str, f"{log_prefix}_runtime_experiment.log", rank) + + return ( + np.array(training_loss_scores), + np.array(validation_loss_scores), + np.array(validation_accuracy_scores), + ) + + +def main( + root_dir: str = ".", + backend: str = "nccl", + epochs: int = 3, + lr: float = 0.001, + runs: int = 1, + log_dir: str = "logs", + node_rank_placement_file: Optional[str] = None, + use_cache: bool = False, +): + _communicator = backend.lower() + dset_name = "igb260m" + assert _communicator.lower() in [ + "nccl", + "nvshmem", + "mpi", + ], "Invalid backend" + + node_rank_placement = None + + if not dist.is_initialized(): + dist.init_process_group(backend="nccl") + comm = Communicator.init_process_group(_communicator) + + # Must pass the node rank placement file the first time + if node_rank_placement_file is not None: + assert os.path.exists( + node_rank_placement_file + ), "Node rank placement file not found" + node_rank_placement = torch.load(node_rank_placement_file, weights_only=False) + + TimingReport.init(comm) + safe_create_dir(log_dir, comm.get_rank()) + training_dataset = DistributedIGBWrapper( + root=root_dir, + comm=comm, + node_rank_placement=node_rank_placement, + ) + + num_classes = training_dataset.num_classes + in_dims = training_dataset.num_features + + training_trajectores = np.zeros((runs, epochs)) + validation_trajectores = np.zeros((runs, 
epochs)) + validation_accuracies = np.zeros((runs, epochs)) + world_size = comm.get_world_size() + + comm.barrier() + + print(f"Using cache: {use_cache}") + + for i in range(runs): + log_prefix = f"{log_dir}/{dset_name}_{world_size}_cache={use_cache}_run_{i}" + training_traj, val_traj, val_accuracy = _run_experiment( + training_dataset, + comm, + lr, + epochs, + log_prefix, + use_cache=use_cache, + num_classes=num_classes, + dset_name=dset_name, + in_dim=in_dims, + ) + training_trajectores[i] = training_traj + validation_trajectores[i] = val_traj + validation_accuracies[i] = val_accuracy + + write_experiment_log( + json.dumps(TimingReport._timers), + f"{log_dir}/{dset_name}_timing_report_world_size_{world_size}_cache_{use_cache}.json", + comm.get_rank(), + ) + visualize_trajectories( + training_trajectores, + "Training Loss", + f"{log_dir}/{dset_name}_training_loss.png", + comm.get_rank(), + ) + visualize_trajectories( + validation_trajectores, + "Validation Loss", + f"{log_dir}/{dset_name}_validation_loss.png", + comm.get_rank(), + ) + visualize_trajectories( + validation_accuracies, + "Validation Accuracy", + f"{log_dir}/{dset_name}_validation_accuracy.png", + comm.get_rank(), + ) + cleanup() + + +if __name__ == "__main__": + fire.Fire(main) diff --git a/experiments/IGB/utils.py b/experiments/IGB/utils.py new file mode 100644 index 0000000..30fe867 --- /dev/null +++ b/experiments/IGB/utils.py @@ -0,0 +1,64 @@ +import torch.distributed as dist +import numpy as np +import matplotlib.pyplot as plt +import os + + +def cleanup(): + if dist.is_initialized(): + dist.destroy_process_group() + + +def make_experiment_log(fname, rank): + if rank == 0: + with open(fname, "w") as f: + f.write("") + + +def write_experiment_log(log: str, fname: str, rank: int): + if rank == 0: + with open(fname, "a") as f: + f.write(log + "\n") + + +def dist_print_ephemeral( + msg, + rank, +): + if rank == 0: + print(msg, end="\r") + + +def visualize_trajectories(trajectories, title, figsave, rank): + if rank != 0: + return + mean = np.mean(trajectories, axis=0) + std = np.std(trajectories, axis=0) + x = np.arange(len(mean)) + + fig, ax = plt.subplots() + + ax.plot(x, mean, "-") + ax.fill_between(x, mean - std, mean + std, alpha=0.2) + + ax.set_xlabel("Epochs") + ax.set_ylabel("Loss") + ax.set_title(title) + fig.savefig(figsave) + + +def safe_create_dir(directory, rank): + if rank == 0: + if not os.path.exists(directory): + os.makedirs(directory) + + +def calculate_accuracy(pred, labels): + if len(labels) == 0: + return 0.0 + pred = pred.argmax(dim=1) + correct = pred.eq(labels).sum().item() + if len(labels) > 0: + return correct / len(labels) * 100 + else: + return 0.0 diff --git a/experiments/OGB/main.py b/experiments/OGB/main.py index 47e356b..5c4d802 100644 --- a/experiments/OGB/main.py +++ b/experiments/OGB/main.py @@ -408,25 +408,25 @@ def main( write_experiment_log( json.dumps(TimingReport._timers), - f"{log_dir}/timing_report_world_size_{world_size}_cache_{use_cache}.json", + f"{log_dir}/{dataset}_timing_report_world_size_{world_size}_cache_{use_cache}.json", comm.get_rank(), ) visualize_trajectories( training_trajectores, "Training Loss", - f"{log_dir}/training_loss.png", + f"{log_dir}/{dataset}_training_loss.png", comm.get_rank(), ) visualize_trajectories( validation_trajectores, "Validation Loss", - f"{log_dir}/validation_loss.png", + f"{log_dir}/{dataset}_validation_loss.png", comm.get_rank(), ) visualize_trajectories( validation_accuracies, "Validation Accuracy", - f"{log_dir}/validation_accuracy.png", + 
f"{log_dir}/{dataset}_validation_accuracy.png", comm.get_rank(), ) cleanup() From b1fac8bc3a93eb21a14a5fd71f3649b02b1cd6dd Mon Sep 17 00:00:00 2001 From: M Shehtab Zaman Date: Tue, 2 Sep 2025 21:24:12 -0700 Subject: [PATCH 8/8] Updated IGB cache files --- experiments/IGB/GenerateCache.py | 17 +----- experiments/IGB/IGB260MDataset.py | 94 ++++++++++++++++++++++++------- experiments/IGB/utils.py | 12 ++++ 3 files changed, 87 insertions(+), 36 deletions(-) diff --git a/experiments/IGB/GenerateCache.py b/experiments/IGB/GenerateCache.py index 562fedf..5c573e8 100644 --- a/experiments/IGB/GenerateCache.py +++ b/experiments/IGB/GenerateCache.py @@ -24,14 +24,7 @@ from time import perf_counter from tqdm import tqdm from multiprocessing import get_context - - -cache_prefix = { - "ogbn-arxiv": "arxiv", - "ogbn-products": "products", - "ogbn-papers100M": "papers100M", - "ogbn-proteins": "proteins", -} +from utils import DummyComm def generate_cache_file( @@ -85,12 +78,6 @@ def generate_cache_file( return 0 -class DummyComm: - def __init__(self, world_size: int): - self.world_size = world_size - self.rank = 0 - - def main(root, world_size: int, node_rank_placement_file=None): assert world_size > 0 @@ -125,7 +112,7 @@ def main(root, world_size: int, node_rank_placement_file=None): os.makedirs("cache", exist_ok=True) os.makedirs("cache/IGB", exist_ok=True) - with get_context("spawn").Pool(min(world_size, 8)) as pool: + with get_context("spawn").Pool(min(world_size, 2)) as pool: args = [ ( dist_graph, diff --git a/experiments/IGB/IGB260MDataset.py b/experiments/IGB/IGB260MDataset.py index a8c4d6f..6ec3f3d 100644 --- a/experiments/IGB/IGB260MDataset.py +++ b/experiments/IGB/IGB260MDataset.py @@ -3,6 +3,8 @@ import os.path as osp from DGraph.CommunicatorBase import CommunicatorBase from DGraph.data.ogbn_datasets import process_homogenous_data +import os +from utils import DummyComm def assign_node_rank(num_nodes, world_size): @@ -15,11 +17,17 @@ def assign_node_rank(num_nodes, world_size): return torch.from_numpy(arr).long() +def partitioned_saver(graph_obj, graph_file_path, rank): + # Only save what we need + torch.save(graph_obj, graph_file_path + f".part{rank}") + + class DistributedIGBWrapper: def __init__( self, root, comm, + graph_file_path=None, node_rank_placement=None, sim_node_features=True, num_features=1, @@ -29,21 +37,57 @@ def __init__( self.rank = comm.get_rank() self.world_size = comm.get_world_size() self.num_features = num_features - self.num_nodes = 227130858 + self.num_nodes = 269346174 self.num_edges = 3727095830 self.num_classes = 19 self.sim_node_features = sim_node_features if node_rank_placement is None: node_rank_placement = assign_node_rank(self.num_nodes, self.world_size) - self.load_graph_data(node_rank_placement) + print(f"Node Rank Placement shape: {node_rank_placement.shape}") + graph_file_path = ( + graph_file_path + if graph_file_path is not None + else osp.join("graph_cache", f"IGB260M_graph_data_{self.world_size}.pt") + ) + + self.graph_file_path = graph_file_path + + if os.path.exists(graph_file_path + f".part{self.rank}"): + print(f"Loading cached graph from {graph_file_path}.part{self.rank}") + self.graph_obj = torch.load( + graph_file_path + f".part{self.rank}", weights_only=False + ) + else: + if os.path.exists(graph_file_path): + print(f"Loading cached graph from {graph_file_path}") + graph_obj = torch.load(graph_file_path, weights_only=False) + self._load_slices_graph(graph_obj) + else: + print("Processing graph data") + self.load_graph_data(node_rank_placement) 
+ + def _load_slices_graph(self, graph_obj): + print("Slicing and saving the graph data") + tensor_dict = { + "node_feat": graph_obj.get_local_node_features(rank=self.rank), + "edge_index": graph_obj.get_global_edge_indices(), + "rank_mapping": graph_obj.get_global_rank_mappings(), + "labels": graph_obj.get_local_labels(rank=self.rank), + } + + self.graph_obj = tensor_dict + os.makedirs("graph_cache", exist_ok=True) + partitioned_saver(self.graph_obj, self.graph_file_path, self.rank) def load_graph_data(self, node_rank_placement): processed_dir = osp.join(self.root, "processed") edge_dir = osp.join(processed_dir, "paper__cites__paper") node_features_dir = osp.join(processed_dir, "paper") edges = np.load(osp.join(edge_dir, "edge_index.npy"), mmap_mode="r") + edges = edges.T + print(edges.shape) - graph_data = {"edge_index": edges} + graph_data = {"edge_index": edges, "num_nodes": self.num_nodes} if self.sim_node_features: node_features = torch.randn( @@ -59,18 +103,20 @@ def load_graph_data(self, node_rank_placement): self.num_features = 1024 graph_data["node_feat"] = node_features + graph_data["edge_feat"] = None labels = np.memmap( osp.join(node_features_dir, "node_label_19.npy"), mode="r", dtype="float32", ) + print(labels.shape) n_train = int(self.num_nodes * 0.6) n_val = int(self.num_nodes * 0.2) - train_mask = torch.zeros(self.num_nodes, dtype=torch.bool) - val_mask = torch.zeros(self.num_nodes, dtype=torch.bool) - test_mask = torch.zeros(self.num_nodes, dtype=torch.bool) + train_mask = np.zeros(self.num_nodes, dtype=bool) + val_mask = np.zeros(self.num_nodes, dtype=bool) + test_mask = np.zeros(self.num_nodes, dtype=bool) train_mask[:n_train] = True val_mask[n_train : n_train + n_val] = True @@ -89,27 +135,33 @@ def load_graph_data(self, node_rank_placement): split_idx, node_rank_placement, ) - self.graph_obj = graph_obj + + self._load_slices_graph(graph_obj) def __len__(self) -> int: return 1 def __getitem__(self, idx: int): - rank = self.comm_object.get_rank() - local_node_features = self.graph_obj.get_local_node_features(rank=rank) - labels = self.graph_obj.get_local_labels(rank=rank) + local_node_features = self.graph_obj["node_feat"] + labels = self.graph_obj["labels"] + edge_indices = self.graph_obj["edge_index"] + rank_mappings = self.graph_obj["rank_mapping"] - # TODO: Move this to a backend-specific collator in the future - if self.comm_object.backend == "nccl": - # Return Graph object with Rank placement data + return local_node_features, edge_indices, rank_mappings, labels - # NOTE: Two-sided comm needs all the edge indices not the local ones - edge_indices = self.graph_obj.get_global_edge_indices() - rank_mappings = self.graph_obj.get_global_rank_mappings() - else: - # One-sided communication, no need for rank placement data - edge_indices = self.graph_obj.get_local_edge_indices(rank=rank) - rank_mappings = self.graph_obj.get_local_rank_mappings(rank=rank) +if __name__ == "__main__": + import argparse - return local_node_features, edge_indices, rank_mappings, labels + parser = argparse.ArgumentParser() + parser.add_argument("--root", type=str, default="data/IGB") + parser.add_argument("--world_size", type=int, default=2) + + args = parser.parse_args() + root = args.root + world_size = args.world_size + node_rank_placement_file = None + + for i in range(world_size): + comm = DummyComm(world_size, rank=i) + dataset = DistributedIGBWrapper(root, comm) diff --git a/experiments/IGB/utils.py b/experiments/IGB/utils.py index 30fe867..bfc1aa3 100644 --- a/experiments/IGB/utils.py 
+++ b/experiments/IGB/utils.py @@ -4,6 +4,18 @@ import os +class DummyComm: + def __init__(self, world_size: int, rank: int = 0): + self.world_size = world_size + self.rank = rank + + def get_world_size(self): + return self.world_size + + def get_rank(self): + return self.rank + + def cleanup(): if dist.is_initialized(): dist.destroy_process_group()
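
For reference, the preprocessing entry point added to IGB260MDataset.py in PATCH 8/8 writes one graph_cache/IGB260M_graph_data_{world_size}.pt.part{rank} file per rank through partitioned_saver. Below is a minimal illustrative sketch (not part of the patches) of how those per-rank partitions could be sanity-checked offline; it only uses the DummyComm stub and DistributedIGBWrapper defined above, and the root path and world size are placeholder assumptions that must match the values used when the cache was written.

# Illustrative sketch: assumes the experiments/IGB modules above are importable
# and that the per-rank partitions have already been generated.
from IGB260MDataset import DistributedIGBWrapper
from utils import DummyComm

ROOT = "data/IGB"   # placeholder dataset root
WORLD_SIZE = 2      # placeholder; must match the world size used for partitioning

for rank in range(WORLD_SIZE):
    comm = DummyComm(WORLD_SIZE, rank=rank)
    # Loads graph_cache/IGB260M_graph_data_{WORLD_SIZE}.pt.part{rank} when present;
    # otherwise it falls back to processing the raw IGB files under ROOT.
    dataset = DistributedIGBWrapper(root=ROOT, comm=comm)
    node_feat, edge_index, rank_mapping, labels = dataset[0]
    print(rank, node_feat.shape, edge_index.shape, rank_mapping.shape, labels.shape)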