Skip to content

Commit 81690d2

Browse files
authored
fix message passing (#95)
1 parent ce87359 commit 81690d2

File tree

2 files changed

+92
-3
lines changed

2 files changed

+92
-3
lines changed

gns/graph_network.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -187,6 +187,13 @@ def forward(
187187
# Start propagating messages.
188188
# Takes in the edge indices and all additional data which is needed to
189189
# construct messages and to update node embeddings.
190+
# Call PyG propagate() method:
191+
# 1. Message phase - compute messages for each edge
192+
# 2. Aggregate phase - aggregate messages for each node
193+
# 3. Update phase - updates only the node features
194+
# Update uses the message from step 1 and any original arguments passed to
195+
# propagate() to update the node embeddings. This is why we need to store
196+
# the updated edge features to return them from the update() method.
190197
x, edge_features = self.propagate(
191198
edge_index=edge_index, x=x, edge_features=edge_features
192199
)
@@ -212,8 +219,8 @@ def message(
212219
"""
213220
# Concat edge features with a final shape of [nedges, latent_dim*3]
214221
edge_features = torch.cat([x_i, x_j, edge_features], dim=-1)
215-
edge_features = self.edge_fn(edge_features)
216-
return edge_features
222+
self._edge_features = self.edge_fn(edge_features) # Create and store
223+
return self._edge_features # This gets passed to aggregate()
217224

218225
def update(
219226
self, x_updated: torch.tensor, x: torch.tensor, edge_features: torch.tensor
@@ -233,9 +240,13 @@ def update(
233240
"""
234241
# Concat node features with a final shape of
235242
# [nparticles, latent_dim (or nnode_in) *2]
243+
# This gets called later, after message() and aggregate()
244+
# Update modified from MessagePassing takes the output of aggregation
245+
# as first argument and any argument which was initially passed to
246+
# propagate hence we need to return the stored value of edge_features
236247
x_updated = torch.cat([x_updated, x], dim=-1)
237248
x_updated = self.node_fn(x_updated)
238-
return x_updated, edge_features
249+
return x_updated, self._edge_features
239250

240251

241252
class Processor(MessagePassing):

test/test_message_edge_features.py

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
from gns.graph_network import *
2+
import torch
3+
from torch_geometric.data import Data
4+
import pytest
5+
6+
7+
@pytest.fixture
def interaction_network_data():
    """Build a tiny InteractionNetwork and a two-node toy graph.

    Returns:
        Tuple of (model, node_features, edge_index, edge_features), sized so a
        single message-passing step is cheap to run in tests.
    """
    net = InteractionNetwork(
        nnode_in=2,
        nnode_out=2,
        nedge_in=2,
        nedge_out=2,
        nmlp_layers=2,
        mlp_hidden_dim=2,
    )
    # Two nodes connected in both directions (0 -> 1 and 1 -> 0).
    connectivity = torch.tensor([[0, 1], [1, 0]], dtype=torch.long)
    node_features = torch.tensor([[1, 2], [3, 4]], dtype=torch.float)
    edge_features = torch.tensor([[1, 1], [2, 2]], dtype=torch.float)
    return net, node_features, connectivity, edge_features
22+
23+
24+
def test_edge_update(interaction_network_data):
    """Test if edge features are updated and finite and are not simply doubled."""
    model, x, edge_index, edge_attr = interaction_network_data
    old_edge_attr = edge_attr.clone()  # Save the old edge features

    # One message passing step
    _, updated_edge_attr = model(x=x, edge_index=edge_index, edge_features=edge_attr)

    # Check if edge features shape is correct.
    # BUG FIX: previously compared the *input* tensor against its own clone,
    # which is trivially true; the message clearly intends the model output.
    assert (
        updated_edge_attr.shape == old_edge_attr.shape
    ), f"Edge features shape is not preserved, changed from {old_edge_attr.shape} to {updated_edge_attr.shape}"
    # Check if edge features are updated (a genuine transform, not a doubling)
    assert not torch.equal(
        updated_edge_attr, old_edge_attr * 2
    ), "Edge features are simply doubled"
    assert not torch.equal(
        updated_edge_attr, old_edge_attr
    ), "Edge features are not updated"
    # Check if edge features are finite.
    # BUG FIX: previously inspected the unmodified input `edge_attr`;
    # the updated features are what must be finite.
    assert torch.all(torch.isfinite(updated_edge_attr)), "Edge features are not finite"
45+
46+
47+
def test_gradients_computed(interaction_network_data):
    """Test that backprop reaches both leaf inputs and yields finite gradients."""
    model, x, edge_index, edge_attr = interaction_network_data
    x.requires_grad = True
    edge_attr.requires_grad = True

    # First message-passing step.
    node_delta, latest_edge_features = model(
        x=x, edge_index=edge_index, edge_features=edge_attr
    )
    node_state = x + node_delta
    # Second message-passing step, feeding back the updated edge features.
    node_delta, latest_edge_features = model(
        x=node_state,
        edge_index=edge_index,
        edge_features=latest_edge_features,
    )
    node_state = node_state + node_delta
    # Backpropagate from a scalar loss over the final edge features.
    latest_edge_features.sum().backward()

    # Gradients must exist on both leaf tensors...
    assert x.grad is not None, "Gradients for node features are not computed"
    assert edge_attr.grad is not None, "Gradients for edge features are not computed"
    # ...and must contain no NaN/Inf values.
    assert torch.all(
        torch.isfinite(x.grad)
    ), "Gradients for node features are not finite"
    assert torch.all(
        torch.isfinite(edge_attr.grad)
    ), "Gradients for edge features are not finite"

0 commit comments

Comments
 (0)