
Commit 160c291

Added trainer files
- Add GCN implementation with loops
- Add dataset wrapper for arxiv
- To do: Add MAG dataset
- rebased
1 parent 62f358b commit 160c291


4 files changed: 251 additions, 0 deletions
Lines changed: 91 additions & 0 deletions
@@ -0,0 +1,91 @@
import lbann
from lbann.modules import Module, ChannelwiseFullyConnectedModule, ConvolutionModule
import lbann.modules


class GCN(Module):
    """
    Graph convolutional kernel
    """

    def __init__(
        self,
        num_nodes,
        num_edges,
        input_features,
        output_features,
        activation=lbann.Relu,
        distconv_enabled=True,
        num_groups=4,
    ):
        super().__init__()
        self._input_dims = input_features
        self._output_dims = output_features
        self._num_nodes = num_nodes
        self._num_edges = num_edges
        # forward() needs the activation and the scatter output shape
        self._activation = activation
        self._ft_dims = [num_nodes, output_features]

    def forward(self, node_features, source_indices, target_indices):
        # Gather each edge's target-node features, transform them, then
        # scatter the results back onto the source nodes
        x = lbann.Gather(node_features, target_indices, axis=0)
        x = lbann.ChannelwiseFullyConnected(x, output_channel_dims=self._output_dims)
        x = self._activation(x)
        x = lbann.Scatter(x, source_indices, dims=self._ft_dims)
        return x


def create_model(num_nodes, num_edges, input_features, output_features, num_layers=3):
    """
    Create a GCN model
    """
    # Layer graph
    input_ = lbann.Input()

    # Boundaries of the flat sample: node features, source indices,
    # target indices, labels
    split_indices = [0, num_nodes * input_features]
    split_indices += [split_indices[-1] + num_edges]
    split_indices += [split_indices[-1] + num_edges]
    split_indices += [split_indices[-1] + num_nodes]

    sliced_input = lbann.Slice(input_, slice_points=split_indices)

    node_features = lbann.Reshape(
        lbann.Identity(sliced_input), dims=[num_nodes, input_features]
    )

    source_indices = lbann.Reshape(lbann.Identity(sliced_input), dims=[num_edges])
    target_indices = lbann.Reshape(lbann.Identity(sliced_input), dims=[num_edges])
    label = lbann.Reshape(lbann.Identity(sliced_input), dims=[num_nodes])

    x = GCN(
        num_nodes,
        num_edges,
        input_features,
        output_features,
        activation=lbann.Relu,
        distconv_enabled=False,
        num_groups=4,
    )(node_features, source_indices, target_indices)

    for _ in range(num_layers - 1):
        x = GCN(
            num_nodes,
            num_edges,
            output_features,  # hidden layers consume the previous layer's output
            output_features,
            activation=lbann.Relu,
            distconv_enabled=False,
            num_groups=4,
        )(x, source_indices, target_indices)

    # Loss function
    loss = lbann.CrossEntropy([x, label])

    # Metrics
    acc = lbann.CategoricalAccuracy([x, label])

    # Callbacks
    callbacks = [lbann.CallbackPrint(), lbann.CallbackTimer()]

    # Construct model
    return lbann.Model(
        num_epochs=1,
        layers=lbann.traverse_layer_graph(input_),
        objective_function=loss,
        metrics=[acc],
        callbacks=callbacks,
    )
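
As a quick usage reference, here is a minimal sketch of driving create_model for an ogbn-arxiv-sized graph. The module name gcn and the choice of 40 output features (ogbn-arxiv's class count) are assumptions of this sketch, not part of the commit.

from gcn import create_model  # assumed module name for the file above

model = create_model(
    num_nodes=169343,     # ogbn-arxiv node count (matches the dataset wrapper)
    num_edges=1166243,    # ogbn-arxiv edge count
    input_features=128,   # node feature dimension
    output_features=40,   # ogbn-arxiv has 40 classes; output width chosen for this sketch
    num_layers=3,
)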
Lines changed: 27 additions & 0 deletions
@@ -0,0 +1,27 @@
import lbann
import os.path as osp


current_dir = osp.dirname(osp.realpath(__file__))

DATASET_CONFIG = {
    "ARXIV": {
        "num_nodes": 169343,
        "num_edges": 1166243,
        "input_features": 128,
    }
}


def make_data_reader(dataset):
    reader = lbann.reader_pb2.DataReader()
    # The python reader fields live on a Reader sub-message,
    # not on the top-level DataReader message
    _reader = reader.reader.add()
    _reader.name = "python"
    _reader.role = "train"
    _reader.shuffle = True
    _reader.percent_of_data_to_use = 1.0
    _reader.python.module = f"{dataset}_dataset"
    _reader.python.module_dir = osp.join(current_dir, "datasets")
    _reader.python.sample_function = "get_train_sample"
    _reader.python.num_samples_function = "num_train_samples"
    _reader.python.sample_dims_function = "sample_dims"
    return reader
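
As a quick check of the wrapper above (saved as dataset_wrapper.py, the name the trainer in this commit imports), the ARXIV reader can be built and its generated protobuf inspected; this sketch assumes nothing beyond the file itself.

from dataset_wrapper import DATASET_CONFIG, make_data_reader

reader = make_data_reader("ARXIV")  # python reader pointing at datasets/ARXIV_dataset.py
print(DATASET_CONFIG["ARXIV"])      # {'num_nodes': 169343, 'num_edges': 1166243, 'input_features': 128}
print(reader)                       # serialized data-reader message handed to the launcher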
Lines changed: 42 additions & 0 deletions
@@ -0,0 +1,42 @@
import numpy as np
import os


# load the dataset

data_dir = "/p/vast1/lbann/datasets/OpenGraphBenchmarks/dataset/ogbn_arxiv"

connectivity_data = np.load(data_dir + "/edges.npy")
node_data = (
    np.load(data_dir + "/node_feats.npy")
    if os.path.exists(data_dir + "/node_feats.npy")
    else np.random.rand(169343, 128)  # random node features
)

labels_data = (
    np.load(data_dir + "/labels.npy")
    if os.path.exists(data_dir + "/labels.npy")
    else np.random.randint(0, 40, 169343)  # random labels
)

num_edges = 1166243
num_nodes = 169343

assert connectivity_data.shape == (num_edges, 2)
assert node_data.shape == (num_nodes, 128)


def get_train_sample(index):
    # Return the complete graph as one flat sample:
    # node features, then edge indices, then labels
    return np.concatenate(
        (node_data.flatten(), connectivity_data.flatten(), labels_data.flatten())
    )


def sample_dims():
    # Total length of the flat sample returned by get_train_sample
    return (node_data.size + connectivity_data.size + labels_data.size,)


def num_train_samples():
    return 1
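
A small sanity check for the sample functions above, assuming the file is importable as ARXIV_dataset (the module name the dataset wrapper derives for the ARXIV config) and that the ogbn_arxiv edge file is readable at the hard-coded path:

import ARXIV_dataset as ds  # assumed module name; loads the graph at import time

sample = ds.get_train_sample(0)
assert sample.shape == ds.sample_dims()  # node feats + edges + labels, flattened
assert ds.num_train_samples() == 1       # the whole graph is a single sample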
Lines changed: 91 additions & 0 deletions
@@ -0,0 +1,91 @@
from dataset_wrapper import DATASET_CONFIG, make_data_reader
# NOTE: module name assumed; the GCN model file added in this commit defines create_model
from gcn import create_model

import lbann
import lbann.contrib.launcher
import lbann.contrib.args

import argparse

desc = "Training a Graph Convolutional Model using LBANN"
parser = argparse.ArgumentParser(description=desc)

lbann.contrib.args.add_scheduler_arguments(parser, "GNN")
lbann.contrib.args.add_optimizer_arguments(parser)

parser.add_argument(
    "--num-epochs",
    action="store",
    default=100,
    type=int,
    help="number of epochs (default: 100)",
    metavar="NUM",
)

parser.add_argument(
    "--model",
    action="store",
    default="GCN",
    type=str,
    help="The type of model to use",
    metavar="NAME",
)

parser.add_argument(
    "--dataset",
    action="store",
    default="ARXIV",
    type=str,
    help="The dataset to use",
    metavar="NAME",
)

parser.add_argument(
    "--latent-dim",
    action="store",
    default=16,
    type=int,
    help="The latent dimension of the model",
    metavar="NUM",
)

parser.add_argument(
    "--num-layers",
    action="store",
    default=3,
    type=int,
    help="The number of layers in the model",
    metavar="NUM",
)


SUPPORTED_MODELS = ["GCN", "GAT"]
SUPPORTED_DATASETS = ["ARXIV", "PRODUCTS", "MAG240M"]


def main():
    args = parser.parse_args()

    kwargs = lbann.contrib.args.get_scheduler_kwargs(args)

    num_epochs = args.num_epochs
    mini_batch_size = 1
    job_name = args.job_name
    model_arch = args.model
    dataset = args.dataset

    if model_arch not in SUPPORTED_MODELS:
        raise ValueError(
            f"Model {model_arch} not supported. Supported models are {SUPPORTED_MODELS}"
        )

    if dataset not in SUPPORTED_DATASETS:
        raise ValueError(
            f"Dataset {dataset} not supported. Supported datasets are {SUPPORTED_DATASETS}"
        )
    dataset_config = DATASET_CONFIG[dataset]
    num_nodes = dataset_config["num_nodes"]
    num_edges = dataset_config["num_edges"]
    input_features = dataset_config["input_features"]

    # Build the trainer, model, and data reader, then launch.
    # Passing the latent dimension as create_model's output_features is an
    # assumption about how this work-in-progress trainer is meant to be wired.
    trainer = lbann.Trainer(mini_batch_size=mini_batch_size)
    model = create_model(
        num_nodes,
        num_edges,
        input_features,
        args.latent_dim,
        num_layers=args.num_layers,
    )
    data_reader = make_data_reader(dataset)

    optimizer = lbann.SGD(learn_rate=0.01, momentum=0.0)
    lbann.contrib.launcher.run(
        trainer, model, data_reader, optimizer, job_name=job_name, **kwargs
    )


if __name__ == "__main__":
    main()
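
Assuming the trainer above is saved as, say, train_gnn.py (the filename is not part of this diff), a run would look like: python train_gnn.py --dataset ARXIV --model GCN --num-layers 3 --latent-dim 16, plus whatever scheduler arguments lbann.contrib.args registers for the target system.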
