diff --git a/applications/graph/MeshGraphNet/GNN.py b/applications/graph/MeshGraphNet/GNN.py
new file mode 100644
index 00000000000..0676f843c7b
--- /dev/null
+++ b/applications/graph/MeshGraphNet/GNN.py
@@ -0,0 +1,156 @@
+import lbann
+from GNNComponents import MLP, GraphProcessor
+
+
+def input_data_splitter(
+    input_layer, num_nodes, num_edges, in_dim_node, in_dim_edge, out_dim
+):
+    """Takes a flattened sample from the Python DataReader and slices it
+    according to the graph attributes.
+    """
+
+    split_indices = []
+    start_index = 0
+    node_feature_size = num_nodes * in_dim_node
+    edge_feature_size = num_edges * in_dim_edge
+    out_feature_size = num_nodes * out_dim
+
+    split_indices.append(start_index)
+    split_indices.append(split_indices[-1] + node_feature_size)
+    split_indices.append(split_indices[-1] + edge_feature_size)
+    split_indices.append(split_indices[-1] + num_edges)
+    split_indices.append(split_indices[-1] + num_edges)
+    split_indices.append(split_indices[-1] + out_feature_size)
+
+    sliced_input = lbann.Slice(input_layer, axis=0, slice_points=split_indices)
+
+    # Each Identity child of the Slice layer consumes the next slice in order
+    node_features = lbann.Reshape(
+        lbann.Identity(sliced_input), dims=[num_nodes, in_dim_node]
+    )
+    edge_features = lbann.Reshape(
+        lbann.Identity(sliced_input), dims=[num_edges, in_dim_edge]
+    )
+    source_node_indices = lbann.Reshape(lbann.Identity(sliced_input), dims=[num_edges])
+    target_node_indices = lbann.Reshape(lbann.Identity(sliced_input), dims=[num_edges])
+
+    out_features = lbann.Reshape(
+        lbann.Identity(sliced_input), dims=[num_nodes, out_dim]
+    )
+
+    return (
+        node_features,
+        edge_features,
+        source_node_indices,
+        target_node_indices,
+        out_features,
+    )
+
+
+def LBANN_GNN_Model(
+    num_nodes,
+    num_edges,
+    in_dim_node,
+    in_dim_edge,
+    out_dim,
+    out_dim_node=128,
+    out_dim_edge=128,
+    hidden_dim_node=128,
+    hidden_dim_edge=128,
+    hidden_layers_node=2,
+    hidden_layers_edge=2,
+    mp_iterations=15,
+    hidden_dim_processor_node=128,
+    hidden_dim_processor_edge=128,
+    hidden_layers_processor_node=2,
+    hidden_layers_processor_edge=2,
+    norm_type=lbann.LayerNorm,
+    hidden_dim_decoder=128,
+    hidden_layers_decoder=2,
+    num_epochs=10,
+):
+    # Set up model modules and associated weights
+
+    node_encoder = MLP(
+        in_dim=in_dim_node,
+        out_dim=out_dim_node,
+        hidden_dim=hidden_dim_node,
+        hidden_layers=hidden_layers_node,
+        norm_type=norm_type,
+        name="graph_input_node_encoder",
+    )
+
+    edge_encoder = MLP(
+        in_dim=in_dim_edge,
+        out_dim=out_dim_edge,
+        hidden_dim=hidden_dim_edge,
+        hidden_layers=hidden_layers_edge,
+        norm_type=norm_type,
+        name="graph_input_edge_encoder",
+    )
+
+    # The graph processor currently only implements homogeneous node graphs,
+    # so we do not distinguish between world and mesh nodes. LBANN supports
+    # heterogeneous graphs and multigraphs in general.
+
+    # We also disable adaptive remeshing, as that may require recomputing
+    # the compute graph due to changing graph characteristics
+    graph_processor = GraphProcessor(
+        num_nodes=num_nodes,
+        mp_iterations=mp_iterations,
+        in_dim_node=out_dim_node,
+        in_dim_edge=out_dim_edge,
+        hidden_dim_node=hidden_dim_processor_node,
+        hidden_dim_edge=hidden_dim_processor_edge,
+        hidden_layers_node=hidden_layers_processor_node,
+        hidden_layers_edge=hidden_layers_processor_edge,
+        norm_type=norm_type,
+    )
+
+    node_decoder = MLP(
+        in_dim=out_dim_node,
+        out_dim=out_dim,
+        hidden_dim=hidden_dim_decoder,
+        hidden_layers=hidden_layers_decoder,
+        norm_type=None,
+        name="graph_input_node_decoder",
+    )
+
+    # Define the LBANN compute graph
+
+    input_layer = lbann.Input(data_field="samples")
+
+    (
+        node_features,
+        edge_features,
+        source_node_indices,
+        target_node_indices,
+        out_features,
+    ) = input_data_splitter(
+        input_layer, num_nodes, num_edges, in_dim_node, in_dim_edge, out_dim
+    )
+
+    node_features = node_encoder(node_features)
+    edge_features = edge_encoder(edge_features)
+
+    node_features, _ = graph_processor(
+        node_features, edge_features, source_node_indices, target_node_indices
+    )
+
+    calculated_features = node_decoder(node_features)
+
+    loss = lbann.MeanSquaredError(calculated_features, out_features)
+
+    # Define some of the usual callbacks
+
+    training_output = lbann.CallbackPrint(interval=1, print_global_stat_only=False)
+    gpu_usage = lbann.CallbackGPUMemoryUsage()
+    timer = lbann.CallbackTimer()
+    callbacks = [training_output, gpu_usage, timer]
+
+    # Put it all together and compile the model
+
+    layers = lbann.traverse_layer_graph(input_layer)
+    model = lbann.Model(
+        num_epochs, layers=layers, objective_function=loss, callbacks=callbacks
+    )
+    return model
diff --git a/applications/graph/MeshGraphNet/GNNComponents.py b/applications/graph/MeshGraphNet/GNNComponents.py
new file mode 100644
index 00000000000..f99440d2a60
--- /dev/null
+++ b/applications/graph/MeshGraphNet/GNNComponents.py
@@ -0,0 +1,264 @@
+import lbann
+from lbann.modules import Module, ChannelwiseFullyConnectedModule
+
+
+class MLP(Module):
+    """
+    Channel-wise MLP with ReLU activations, an optional normalization layer
+    applied to the output, and a configurable number of hidden layers
+    """
+
+    global_count = 0
+
+    def __init__(
+        self,
+        in_dim,
+        out_dim,
+        hidden_dim,
+        hidden_layers,
+        norm_type=lbann.LayerNorm,
+        name=None,
+    ):
+        super().__init__()
+        MLP.global_count += 1
+
+        self.instance = 0
+        self.in_dim = in_dim
+        self.out_dim = out_dim
+        self.hidden_dim = hidden_dim
+        self.hidden_layers = hidden_layers
+
+        self.name = name if name else f"MLP_{MLP.global_count}"
+
+        self.layers = [
+            ChannelwiseFullyConnectedModule(
+                hidden_dim, bias=True, activation=lbann.Relu
+            )
+        ]
+        for i in range(hidden_layers):
+            # Total number of layers is hidden_layers + 2 (input and output)
+            self.layers.append(
+                ChannelwiseFullyConnectedModule(
+                    hidden_dim, bias=True, activation=lbann.Relu
+                )
+            )
+
+        self.layers.append(
+            ChannelwiseFullyConnectedModule(out_dim, bias=True, activation=None)
+        )
+
+        self.norm_type = None
+
+        if norm_type:
+            # Accept either a layer class or a layer instance
+            if isinstance(norm_type, type):
+                self.norm_type = norm_type
+            else:
+                self.norm_type = type(norm_type)
+
+            if not issubclass(self.norm_type, lbann.Layer):
+                raise ValueError("Normalization must be a layer")
+
+    def forward(self, x):
+        """
+        Args:
+            x (Layer) : Expected shape (Batch, N, self.in_dim)
+
+        Returns:
+            (Layer): Expected shape (Batch, N, self.out_dim)
+        """
+        self.instance += 1
+        name = f"{self.name}_instance_{self.instance}"
+
+        for layer in self.layers:
+            x = layer(x)
+
+        if self.norm_type:
+            return self.norm_type(x, name=name + "_norm")
+        return x
+
+
+class EdgeProcessor(Module):
+    """Applies an MLP to concatenated node and edge features, with a residual
+    connection on the edge features"""
+
+    global_count = 0
+
+    def __init__(
+        self,
+        in_dim_node=128,
+        in_dim_edge=128,
+        hidden_dim=128,
+        hidden_layers=2,
+        norm_type=lbann.LayerNorm,
+        name=None,
+    ):
+        super().__init__()
+        EdgeProcessor.global_count += 1
+        self.instance = 0
+        self.name = name if name else f"EdgeProcessor_{EdgeProcessor.global_count}"
+
+        self.edge_mlp = MLP(
+            2 * in_dim_node + in_dim_edge,
+            in_dim_edge,
+            hidden_dim=hidden_dim,
+            hidden_layers=hidden_layers,
+            norm_type=norm_type,
+            name=f"{self.name}_edge_mlp",
+        )
+
+    def forward(
+        self,
+        node_features,
+        edge_features,
+        source_node_indices,
+        target_node_indices,
+    ):
+        """
+        Args:
+            node_features (Layer) : Expected shape (Batch, num_nodes, self.in_dim_node)
+            edge_features (Layer) : Expected shape (Batch, num_edges, self.in_dim_edge)
+            source_node_indices (Layer) : Expected shape (Batch, num_edges)
+            target_node_indices (Layer) : Expected shape (Batch, num_edges)
+
+        Returns:
+            (Layer): Expected shape (Batch, num_edges, self.in_dim_edge)
+        """
+        self.instance += 1
+        # Look up the features of each edge's endpoint nodes
+        source_node_features = lbann.Gather(node_features, source_node_indices, axis=0)
+        target_node_features = lbann.Gather(node_features, target_node_indices, axis=0)
+
+        x = lbann.Concatenation(
+            [source_node_features, target_node_features, edge_features],
+            axis=1,
+            name=f"{self.name}_{self.instance}_concat_features",
+        )
+        x = self.edge_mlp(x)
+
+        return lbann.Sum(
+            edge_features, x, name=f"{self.name}_{self.instance}_residual_sum"
+        )
+
+
+class NodeProcessor(Module):
+    """Applies an MLP to node features concatenated with scatter-summed edge
+    features, with a residual connection on the node features"""
+
+    global_count = 0
+
+    def __init__(
+        self,
+        num_nodes,
+        in_dim_node=128,
+        in_dim_edge=128,
+        hidden_dim=128,
+        hidden_layers=2,
+        norm_type=lbann.LayerNorm,
+        name=None,
+    ):
+        super().__init__()
+        NodeProcessor.global_count += 1
+        self.instance = 0
+        self.name = name if name else f"NodeProcessor_{NodeProcessor.global_count}"
+        self.num_nodes = num_nodes
+        self.in_dim_edge = in_dim_edge
+        self.node_mlp = MLP(
+            in_dim_node + in_dim_edge,
+            in_dim_node,
+            hidden_dim=hidden_dim,
+            hidden_layers=hidden_layers,
+            norm_type=norm_type,
+            name=f"{self.name}_node_mlp",
+        )
+
+    def forward(self, node_features, edge_features, target_edge_indices):
+        """
+        Args:
+            node_features (Layer) : Expected shape (Batch, num_nodes, self.in_dim_node)
+            edge_features (Layer) : Expected shape (Batch, num_edges, self.in_dim_edge)
+            target_edge_indices (Layer): Expected shape (Batch, num_edges)
+        Returns:
+            (Layer): Expected shape (Batch, num_nodes, self.in_dim_node)
+        """
+        self.instance += 1
+        name = f"{self.name}_{self.instance}"
+        # Sum the features of incoming edges into their target nodes; roughly
+        # out[indices[e], :] += edge_features[e, :] for every edge e
+        edge_feature_sum = lbann.Scatter(
+            edge_features,
+            target_edge_indices,
+            name=f"{name}_scatter",
+            dims=[self.num_nodes, self.in_dim_edge],
+            axis=0,
+        )
+
+        x = lbann.Concatenation(
+            [node_features, edge_feature_sum],
+            axis=1,
+            name=f"{name}_concat_features",
+        )
+        x = self.node_mlp(x)
+
+        return lbann.Sum(node_features, x, name=f"{name}_residual_sum")
+
+
+class GraphProcessor(Module):
+    """Message-passing processor that alternates edge and node updates"""
+
+    def __init__(
+        self,
+        num_nodes,
+        mp_iterations=15,
+        in_dim_node=128,
+        in_dim_edge=128,
+        hidden_dim_node=128,
+        hidden_dim_edge=128,
+        hidden_layers_node=2,
+        hidden_layers_edge=2,
+        norm_type=lbann.LayerNorm,
+    ):
+        super().__init__()
+
+        self.blocks = []
+
+        for _ in range(mp_iterations):
+            node_processor = NodeProcessor(
+                num_nodes=num_nodes,
+                in_dim_node=in_dim_node,
+                in_dim_edge=in_dim_edge,
+                hidden_dim=hidden_dim_node,
+                hidden_layers=hidden_layers_node,
+                norm_type=norm_type,
+            )
+
+            edge_processor = EdgeProcessor(
+                in_dim_node=in_dim_node,
+                in_dim_edge=in_dim_edge,
+                hidden_dim=hidden_dim_edge,
+                hidden_layers=hidden_layers_edge,
+                norm_type=norm_type,
+            )
+
+            self.blocks.append((node_processor, edge_processor))
+
+    def forward(
+        self, node_features, edge_features, source_node_indices, target_node_indices
+    ):
+        """
+        Args:
+            node_features (Layer) : Expected shape (Batch, num_nodes, self.in_dim_node)
+            edge_features (Layer) : Expected shape (Batch, num_edges, self.in_dim_edge)
+            source_node_indices (Layer) : Expected shape (Batch, num_edges)
+            target_node_indices (Layer) : Expected shape (Batch, num_edges)
+        Returns:
+            (Layer, Layer): Expected shapes (Batch, num_nodes, self.in_dim_node) and
+                (Batch, num_edges, self.in_dim_edge)
+        """
+
+        for node_processor, edge_processor in self.blocks:
+            edge_features = edge_processor(
+                node_features, edge_features, source_node_indices, target_node_indices
+            )
+            node_features = node_processor(
+                node_features, edge_features, target_node_indices
+            )
+
+        return node_features, edge_features
diff --git a/applications/graph/MeshGraphNet/README.md b/applications/graph/MeshGraphNet/README.md
new file mode 100644
index 00000000000..76d771958a5
--- /dev/null
+++ b/applications/graph/MeshGraphNet/README.md
@@ -0,0 +1,19 @@
+## Mesh Graph Networks
+
+This example contains an LBANN implementation of a mesh-based graph neural network (MeshGraphNet) with
+synthetically generated data.
+For more information about the model, refer to: T. Pfaff et al., "Learning Mesh-Based Simulation with Graph Networks". ICLR'21.
+
+---
+### Running the example
+
+The data-parallel model can be run with the synthetic data with:
+
+```bash
+python Trainer.py --mini-batch-size <size> --num-epochs <epochs>
+```
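+
+Each sample is delivered to LBANN as a single flattened vector, which
+`input_data_splitter` in `GNN.py` slices back into graph attributes. A sketch
+of the expected layout, in terms of the sizes in `data_config.ini`:
+
+```
+[ node features | edge features | source indices | target indices | targets     ]
+  NUM_NODES *     NUM_EDGES *     NUM_EDGES        NUM_EDGES        NUM_NODES *
+  NODE_FEATURES   EDGE_FEATURES                                     OUT_FEATURES
+```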
+
+### Notes
+
+- This implementation does not distinguish between world nodes and mesh nodes
+- We do not currently implement adaptive remeshing, as this may require updating the compute graph after each mini-batch
\ No newline at end of file
diff --git a/applications/graph/MeshGraphNet/SyntheticData.py b/applications/graph/MeshGraphNet/SyntheticData.py
new file mode 100644
index 00000000000..47283f0eb13
--- /dev/null
+++ b/applications/graph/MeshGraphNet/SyntheticData.py
@@ -0,0 +1,36 @@
+import numpy as np
+import configparser
+import os.path as osp
+
+# Read the graph dimensions from the config file next to this script, so they
+# stay consistent with Trainer.py regardless of the working directory
+DATA_CONFIG = configparser.ConfigParser()
+DATA_CONFIG.read(osp.join(osp.dirname(osp.realpath(__file__)), "data_config.ini"))
+NUM_NODES = int(DATA_CONFIG['DEFAULT']['NUM_NODES'])
+NUM_EDGES = int(DATA_CONFIG['DEFAULT']['NUM_EDGES'])
+NODE_FEATS = int(DATA_CONFIG['DEFAULT']['NODE_FEATURES'])
+EDGE_FEATS = int(DATA_CONFIG['DEFAULT']['EDGE_FEATURES'])
+OUT_FEATS = int(DATA_CONFIG['DEFAULT']['OUT_FEATURES'])
+NUM_SAMPLES = 100
+
+NODE_FEATURE_SIZE = NUM_NODES * NODE_FEATS
+EDGE_FEATURE_SIZE = NUM_EDGES * EDGE_FEATS
+OUT_FEATURE_SIZE = NUM_NODES * OUT_FEATS
+
+
+def get_sample_func(index):
+    # index is unused; every sample is random. The layout must match
+    # GNN.input_data_splitter: node features, edge features, source indices,
+    # target indices, targets
+    node_and_edge_features = np.random.random(
+        NODE_FEATURE_SIZE + EDGE_FEATURE_SIZE
+    ).astype(np.float32)
+    # Indices of -1 fall outside the valid range and act as padding for
+    # LBANN's Gather/Scatter layers
+    source_indices = np.random.randint(-1, NUM_NODES, size=NUM_EDGES).astype(np.float32)
+    target_indices = np.random.randint(-1, NUM_NODES, size=NUM_EDGES).astype(np.float32)
+    out_features = np.random.random(OUT_FEATURE_SIZE).astype(np.float32)
+
+    return np.concatenate(
+        [node_and_edge_features, source_indices, target_indices, out_features]
+    )
+
+
+def num_samples_func():
+    return NUM_SAMPLES
+
+
+def sample_dims_func():
+    size = NODE_FEATURE_SIZE + EDGE_FEATURE_SIZE + OUT_FEATURE_SIZE + 2 * NUM_EDGES
+    return (size,)
diff --git a/applications/graph/MeshGraphNet/Trainer.py b/applications/graph/MeshGraphNet/Trainer.py
new file mode 100644
index 00000000000..4b829d12137
--- /dev/null
+++ b/applications/graph/MeshGraphNet/Trainer.py
@@ -0,0 +1,90 @@
+import lbann
+import lbann.contrib.launcher
+import lbann.contrib.args
+import argparse
+import configparser
+import os.path as osp
+from GNN import LBANN_GNN_Model
+
+data_dir = osp.dirname(osp.realpath(__file__))
+
+
+desc = "Training a Mesh Graph Neural Network Model Using LBANN"
+
+parser = argparse.ArgumentParser(description=desc)
+
+lbann.contrib.args.add_scheduler_arguments(parser)
+lbann.contrib.args.add_optimizer_arguments(parser)
+
+parser.add_argument(
+    '--num-epochs', action='store', default=3, type=int,
+    help='number of epochs (default: 3)', metavar='NUM')
+
+parser.add_argument(
+    '--mini-batch-size', action='store', default=4, type=int,
+    help="mini-batch size (default: 4)", metavar='NUM')
+
+parser.add_argument(
+    '--job-name', action='store', default="MGN", type=str,
+    help="Job name for scheduler", metavar='NAME')
+
+args = parser.parse_args()
+kwargs = lbann.contrib.args.get_scheduler_kwargs(args)
+
+# Some training parameters
+
+MINI_BATCH_SIZE = args.mini_batch_size
+NUM_EPOCHS = args.num_epochs
+JOB_NAME = args.job_name
+
+# Graph attributes for the synthetic data
+DATA_CONFIG = configparser.ConfigParser()
+DATA_CONFIG.read(osp.join(data_dir, "data_config.ini"))
+
+NUM_NODES = int(DATA_CONFIG['DEFAULT']['NUM_NODES'])
+NUM_EDGES = int(DATA_CONFIG['DEFAULT']['NUM_EDGES'])
+NODE_FEATS = int(DATA_CONFIG['DEFAULT']['NODE_FEATURES'])
+EDGE_FEATS = int(DATA_CONFIG['DEFAULT']['EDGE_FEATURES'])
+OUT_FEATS = int(DATA_CONFIG['DEFAULT']['OUT_FEATURES'])
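+
+# These sizes must agree with SyntheticData.py, which reads the same
+# data_config.ini: GNN.input_data_splitter uses them to slice each flattened
+# sample into node features, edge features, edge indices, and targets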
+
+
+def make_data_reader(classname,
+                     sample='get_sample_func',
+                     num_samples='num_samples_func',
+                     sample_dims='sample_dims_func',
+                     validation_percent=0.1):
+    reader = lbann.reader_pb2.DataReader()
+    _reader = reader.reader.add()
+    _reader.name = 'python'
+    _reader.role = 'train'
+    _reader.shuffle = True
+    _reader.percent_of_data_to_use = 1.0
+    _reader.validation_percent = validation_percent
+    _reader.python.module = classname
+    _reader.python.module_dir = data_dir
+    _reader.python.sample_function = sample
+    _reader.python.num_samples_function = num_samples
+    _reader.python.sample_dims_function = sample_dims
+    return reader
+
+
+def main():
+    # Use the defaults for the other parameters
+    model = LBANN_GNN_Model(num_nodes=NUM_NODES,
+                            num_edges=NUM_EDGES,
+                            in_dim_node=NODE_FEATS,
+                            in_dim_edge=EDGE_FEATS,
+                            out_dim=OUT_FEATS,
+                            num_epochs=NUM_EPOCHS)
+
+    optimizer = lbann.SGD(learn_rate=1e-4)
+    data_reader = make_data_reader("SyntheticData")
+    trainer = lbann.Trainer(mini_batch_size=MINI_BATCH_SIZE)
+
+    lbann.contrib.launcher.run(trainer,
+                               model,
+                               data_reader,
+                               optimizer,
+                               job_name=JOB_NAME,
+                               **kwargs)
+
+
+if __name__ == '__main__':
+    main()
diff --git a/applications/graph/MeshGraphNet/data_config.ini b/applications/graph/MeshGraphNet/data_config.ini
new file mode 100644
index 00000000000..0d641e03b4a
--- /dev/null
+++ b/applications/graph/MeshGraphNet/data_config.ini
@@ -0,0 +1,6 @@
+[DEFAULT]
+NUM_NODES = 100
+NUM_EDGES = 10000
+EDGE_FEATURES = 3
+NODE_FEATURES = 5
+OUT_FEATURES = 3