forked from mtiezzi/torch_gnn
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain_simple.py
More file actions
135 lines (101 loc) · 3.53 KB
/
Copy pathmain_simple.py
File metadata and controls
135 lines (101 loc) · 3.53 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
import torch
import torch.nn as nn
import numpy as np
import argparse
import utils
import dataloader
from gnn_wrapper import GNNWrapper, SemiSupGNNWrapper
import matplotlib.pyplot as plt
import networkx as nx
import net
# # #plotting
# # print performance of one method for now, to be expanded to plot several methods at the same time
# def plot_performance(x_points, y_points,title):
#
# plt.plot(x_points, y_points)
#
# plt.xlabel('#epoch')
# plt.ylabel('Accuracy')
# plt.title(title)
#
# #plt.legend()
# plt.show()
# # #/plotting
# GRAPH #1
# Each row is [source node, target node, graph id]; every arc here belongs to graph 0.
forward_arcs = [[0, 1, 0], [0, 2, 0], [0, 4, 0], [1, 2, 0], [1, 3, 0], [2, 3, 0], [2, 4, 0]]
# The graph is undirected, so every arc is also stored in the opposite direction.
backward_arcs = [[dst, src, gid] for src, dst, gid in forward_arcs]
# Keep the arcs ordered by (source, target) node id.
e = sorted(forward_arcs + backward_arcs)
E = np.asarray(e)
# NOTE: despite its name, `edges` is the number of NODES of graph #1
# (the name is kept because later code reads it).
edges = 5
# Node features: one one-hot row per node, followed by a final column holding
# the id of the graph the node belongs to (0 for every node of this graph).
one_hot = np.eye(edges, dtype=np.float32)
graph_id_column = np.zeros((edges, 1), dtype=np.float32)
N = np.hstack([one_hot, graph_id_column])
# visualization of a graph
def plot_graph(E, N):
    """Draw a graph with a spring layout and display it.

    Args:
        E: edge array; only its first two columns (source, target node ids)
           are used. The graph is an undirected ``nx.Graph``, so an arc and
           its reverse end up as a single drawn edge.
        N: node array; only its row count is used, to add nodes
           ``0 .. N.shape[0] - 1`` (so nodes without arcs are drawn too).
    """
    graph = nx.Graph()
    node_count = N.shape[0]
    graph.add_nodes_from(range(node_count))
    endpoint_pairs = E[:, :2]
    graph.add_edges_from(endpoint_pairs)
    nx.draw_spring(graph, cmap=plt.get_cmap('Set1'), with_labels=True)
    plt.show()
plot_graph(E, N)

# GRAPH #2
# List of edges in the second graph - each row is [source, target, graph id];
# every arc of this graph carries graph id 1.
e1 = [[0, 2, 1], [0, 3, 1], [1, 2, 1], [1, 3, 1], [2, 3, 1]]
# undirected graph, adding the other direction
e1.extend([[i, j, num] for j, i, num in e1])
# shift node ids by the number of nodes of the first graph so that node ids
# stay unique once both graphs are merged into a single dataset
e2 = [[a + N.shape[0], b + N.shape[0], num] for a, b, num in e1]
# reorder
e2 = sorted(e2)
# number of NODES of graph #2 (named after `edges`, which counts graph #1's nodes)
edges_2 = 4
# graph id of graph #2, read from its arcs so node ids and arc ids agree
graph_id_2 = e1[0][2]

# Plot second graph, using its local (unshifted) node ids
E1 = np.asarray(e1)
N1 = np.eye(edges_2, dtype=np.float32)
# last column: id of the graph each node belongs to (same id as graph #2's arcs)
N1 = np.concatenate((N1, np.full((edges_2, 1), graph_id_2, dtype=np.float32)), axis=1)
plot_graph(E1, N1)

# Merge both graphs: arcs of graph #1 followed by the shifted arcs of graph #2
E = np.concatenate((E, np.asarray(e2)), axis=0)
# Node features of the merged dataset: one-hot over all nodes of both graphs,
# plus a final column with each node's graph id. These ids must agree with the
# graph ids stored in E's last column, so graph #1's nodes reuse N's id column
# and graph #2's nodes get graph_id_2 (instead of 0 for every node).
N_tot = np.eye(edges + edges_2, dtype=np.float32)
node_graph_ids = np.concatenate(
    (N[:, -1], np.full(edges_2, graph_id_2, dtype=np.float32))
)
N_tot = np.concatenate((N_tot, node_graph_ids[:, None]), axis=1)
# Create Input to GNN
# Fix the random seeds so that the random targets below and the initial network
# weights are identical on every run, making reported accuracies reproducible.
SEED = 0
torch.manual_seed(SEED)
rng = np.random.default_rng(SEED)
# Random binary targets, one per node of the merged dataset (toy example:
# the labels carry no meaning).
labels = rng.integers(2, size=N_tot.shape[0])

cfg = GNNWrapper.Config()
# Only request CUDA when a GPU is actually present, so the script also runs on
# CPU-only machines.
cfg.use_cuda = torch.cuda.is_available()
# NOTE(review): prepare_device is assumed to fall back to CPU when no GPU
# exists — confirm in utils.
cfg.device = utils.prepare_device(n_gpu_use=1, gpu_id=0)
cfg.tensorboard = False
cfg.epochs = 500
cfg.activation = nn.Tanh()
cfg.state_transition_hidden_dims = [5]
cfg.output_function_hidden_dims = [5]
cfg.state_dim = 5
cfg.max_iterations = 50
cfg.convergence_threshold = 0.01
# targets are per node (`labels` has one entry per node), not per graph
cfg.graph_based = False
cfg.log_interval = 10
cfg.task_type = "multiclass"
cfg.lrw = 0.001
# generate the dataset from the edge matrix E and node matrix N_tot
dset = dataloader.from_EN_to_GNN(E, N_tot, targets=labels, aggregation_type="sum", sparse_matrix=True)
cfg.label_dim = dset.node_label_dim
cfg.state_net = net.GINTransition(node_state_dim=cfg.state_dim,
                                  node_label_dim=cfg.label_dim,
                                  mlp_hidden_dim=cfg.state_transition_hidden_dims,
                                  activation_function=cfg.activation)
# model creation
model = GNNWrapper(cfg)
# load the dataset into the model, using the GIN state-transition network
model(dset, state_net=cfg.state_net)
# training code: one training step per epoch; evaluate every `cfg.log_interval`
# epochs (instead of a hard-coded 10) and always after the final epoch.
# NOTE(review): only one dataset is passed to the model above, so test_step
# presumably evaluates on the same graphs it trains on — confirm.
test_interval = max(1, cfg.log_interval)  # guard against a 0 interval
for epoch in range(1, cfg.epochs + 1):
    model.train_step(epoch)
    if epoch % test_interval == 0 or epoch == cfg.epochs:
        model.test_step(epoch)
# # #plotting
# plot_performance(model.arr_its_train,model.arr_acc_train, "Train set Accuracy ")
# plot_performance(model.arr_its_test,model.arr_acc_test, "Test set Accuracy")
# # #/plotting