
Commit f5e43aa

committed
Added GAT implementation
1 parent 02b7f79 commit f5e43aa

File tree

  • applications/graph/NodePropPrediction

1 file changed

Lines changed: 164 additions & 0 deletions
@@ -0,0 +1,164 @@
import lbann
from lbann.modules import Module, ChannelwiseFullyConnectedModule, ConvolutionModule
import math


def ContractHeads(lbann_graph_layer, shape):
    """A utility function that contracts an (N, M, H) tensor to an (N, M)
    tensor using grouped 2D convolution.

    The contraction computes the average along the head dimension, so the
    output is scaled by 1 / H.

    Args:
        lbann_graph_layer (Layer): Graph layer tensor with shape (N, M, H)

        shape ((int, int, int)): Shape of the graph layer tensor

    Returns:
        (Layer): Contracted and rescaled output with shape (N, M)
    """
    num_nodes, output_channels, num_heads = shape
    # Fixed averaging weights (1 / H) that are never updated during training
    weights = lbann.Weights(
        initializer=lbann.ConstantInitializer(value=1 / num_heads),
        optimizer=lbann.NoOptimizer(),
    )
    # One group per node: each (1, H) kernel averages the H head values of
    # one row of features without mixing rows
    kernel_shape = (1, num_heads)
    contraction = lbann.Convolution(
        lbann_graph_layer,
        num_dims=2,
        output_channels=num_nodes,
        kernel_size=kernel_shape,
        stride=1,
        padding=0,
        groups=num_nodes,
        has_bias=False,
        weights=weights,
    )
    output = lbann.Reshape(contraction, dims=[num_nodes, output_channels])
    return output


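# Usage sketch for ContractHeads (hypothetical shapes, for illustration only):
#
#   h = lbann.Reshape(x, dims=[4, 8, 2])   # (N=4, M=8, H=2)
#   h_avg = ContractHeads(h, (4, 8, 2))    # -> (4, 8), averaged over the 2 heads

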
class GAT(Module):
    """Graph Attention Network layer. For kernel details, see:

    https://arxiv.org/abs/1710.10903

    """

    global_count = 0

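    # The kernel computes, per attention head (notation from the paper above):
    #   m_ij     = LeakyReLU(a^T [W h_i || W h_j])
    #   alpha_ij = exp(m_ij) / sum_{j' in N(i)} exp(m_ij')
    #   h_i'     = sum_{j in N(i)} alpha_ij * (W h_j)
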
    def __init__(
        self,
        input_channels,
        output_channels,
        num_nodes,
        num_edges,
        num_heads=1,
        name=None,
    ):
        """Initialize the GAT layer

        Args:
            input_channels (int): The size of the input node features
            output_channels (int): The output size of the node features
            num_nodes (int): Number of vertices in the graph
            num_edges (int): Number of edges in the graph
            num_heads (int): Number of attention heads (default: 1)
            name (str): Name of the layers and prefix to use for the layers
        """
        super().__init__()

        # Add names for the components of the layer
        GAT.global_count += 1
        self.name = name if name else "GAT_{}".format(GAT.global_count)
        # Add variables
        self.output_channel_size = output_channels
        self.input_channel_size = input_channels
        self.num_nodes = num_nodes
        self.num_edges = num_edges
        self.num_heads = num_heads

        weights = lbann.Weights(
            initializer=lbann.UniformInitializer(
                min=-1 / (math.sqrt(output_channels)),
                max=1 / (math.sqrt(output_channels)),
            )
        )
        # Shared linear transform W applied to every node, with one block of
        # output_channels weights per attention head
        self.W_nn = ChannelwiseFullyConnectedModule(
            self.output_channel_size * num_heads,
            bias=False,
            weights=[weights],
            name=f"{self.name}_nn_1",
        )

        # Attention kernel a: a grouped 2D convolution whose
        # (2 * output_channels, 1) kernel reduces each edge's concatenated
        # [e_i || e_j] features to one score per head. The input to this
        # module has num_edges channels, so the channel and group counts use
        # num_edges
        self.a_vec = ConvolutionModule(
            num_dims=2,
            out_channels=self.num_edges,
            kernel_size=[2 * self.output_channel_size, 1],
            groups=self.num_edges,
            bias=False,
            name=f"{self.name}_nn_2",
        )

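    # Construction sketch (hypothetical sizes, shown for illustration):
    #
    #   gat = GAT(input_channels=5, output_channels=8,
    #             num_nodes=3, num_edges=4, num_heads=2)
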
    def forward(
        self, node_feature_mat, source_indices, target_indices, reduction="concat"
    ):
        """Call the GAT layer

        Args:
            node_feature_mat (Layer): Node feature matrix with the shape of (num_nodes, input_channels)
            source_indices (Layer): Source node indices of the edges with shape (num_edges)
            target_indices (Layer): Target node indices of the edges with shape (num_edges)
            reduction (str: "concat" | "average"): The type of reduction to use for multiple heads
        Returns:
            (Layer): The output after kernel ops. The shape of the layer is
                     (num_nodes, num_heads * output_channels) if reduction is "concat"
                     (num_nodes, output_channels) if reduction is "average"
        """
        # (N x [self.output_channel_size * self.num_heads])
        transform_node_features = self.W_nn(
            node_feature_mat, name=f"{self.name}_transform"
        )
        # (E x [self.output_channel_size * self.num_heads])
        e_i = lbann.Gather(transform_node_features, source_indices, axis=0)
        e_j = lbann.Gather(transform_node_features, target_indices, axis=0)
        # (E x self.output_channel_size x self.num_heads)
        e_i = lbann.Reshape(
            e_i, dims=[self.num_edges, self.output_channel_size, self.num_heads]
        )
        e_j = lbann.Reshape(
            e_j, dims=[self.num_edges, self.output_channel_size, self.num_heads]
        )
        # (E x [2 * self.output_channel_size] x self.num_heads)
        messages = lbann.Concatenation([e_i, e_j], axis=1)
        # (E x self.num_heads): one raw attention score per edge and head
        m_ij = lbann.Reshape(
            self.a_vec(messages), dims=[self.num_edges, self.num_heads]
        )
        # Softmax numerators: exp(LeakyReLU(score)); the paper uses a
        # negative slope of 0.2
        m_ij = lbann.ExpOperator(lbann.LeakyRelu(m_ij, negative_slope=0.2))
        # (N x self.num_heads): per-node softmax denominators, grouped by the
        # same source index used for the aggregation below
        contraction = lbann.Scatter(m_ij, source_indices, axis=0)
        # (E x self.num_heads): each edge reads back its node's denominator
        normalize = lbann.Gather(contraction, source_indices, axis=0)
        # (E x self.num_heads): normalized attention coefficients
        alpha_ij = lbann.DivideOperator(m_ij, normalize)
        # (E x 1 x self.num_heads) -> (E x self.output_channel_size x self.num_heads),
        # so each coefficient scales every feature channel of its edge
        alpha_ij = lbann.Tessellate(
            lbann.Reshape(alpha_ij, dims=[self.num_edges, 1, self.num_heads]),
            dims=[self.num_edges, self.output_channel_size, self.num_heads],
        )

        # (E x self.output_channel_size x self.num_heads): weighted messages
        h_ij = lbann.MultiplyOperator(alpha_ij, e_j)

        # (N x self.output_channel_size x self.num_heads): accumulate the
        # weighted messages onto the nodes
        h_i = lbann.Scatter(h_ij, source_indices, axis=0)

        if reduction.lower() == "concat":
            node_feature_mat = lbann.Reshape(
                h_i, dims=[self.num_nodes, self.output_channel_size * self.num_heads]
            )
        elif reduction.lower() == "average":
            node_feature_mat = ContractHeads(
                h_i, (self.num_nodes, self.output_channel_size, self.num_heads)
            )
        else:
            raise ValueError("Expected reduction arguments are: concat or average")

        return node_feature_mat
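
# Usage sketch (assuming `node_features` is a (num_nodes x input_channels)
# layer and `sources` / `targets` are length-num_edges index layers defined
# elsewhere in the model; `gat` is an instance of the GAT module above):
#
#   out = gat(node_features, sources, targets, reduction="concat")
#   # -> (num_nodes x [num_heads * output_channels])
#   out = gat(node_features, sources, targets, reduction="average")
#   # -> (num_nodes x output_channels)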
