Description
I was experimenting with the Graphormer model, specifically for graph classification using the virtual node for global pooling (`graph_pooling: graph_token`).
Problem
I noticed that the model produces different outputs for the same input graph when its node order is permuted. The problem is easy to reproduce; here is an example:
```python
import torch
from torch_geometric.data import Batch

# Given some data batch `batch` (e.g. inside the training loop) and a `model`,
# create two single-graph batches from copies of the first graph
data = Batch.from_data_list([batch.get_example(0).clone()])
data_p = Batch.from_data_list([batch.get_example(0).clone()])

# Permute the nodes: move the previously last node to the first position
n = data_p.x.size(0)
p = torch.arange(n, dtype=torch.long) - 1
p[0] = n - 1
data_p.x = data_p.x[p]
assert (data_p.x[0, :] == data.x[-1, :]).all()
assert (data_p.x[1:, :] == data.x[:-1, :]).all()

# Permute the other node-level attributes in the same way
data_p.batch = data_p.batch[p]
data_p.in_degrees = data_p.in_degrees[p]
data_p.out_degrees = data_p.out_degrees[p]

# Remap the node indices accordingly: every index increases by one,
# except the old last index (n - 1), which wraps around to 0
data_p.edge_index += 1
data_p.edge_index[data_p.edge_index == n] = 0
data_p.graph_index += 1
data_p.graph_index[data_p.graph_index == n] = 0

# Get the model outputs for both graphs
model.eval()
with torch.no_grad():
    output, _ = model(data)
    output_p, _ = model(data_p)

# Check whether the outputs are equal
assert torch.allclose(output, output_p), "Permuted graph produces different output!"
```
This is unexpected (and worrisome) behavior. In theory, the graph-level output should be invariant to such a relabelling of the nodes, as it should be for any GNN with a permutation-invariant readout.
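The manual remapping in the example only works for that particular cyclic shift. To test invariance with arbitrary permutations, the general rule is: node-level attributes (`x`, `batch`, `in_degrees`, `out_degrees`) are gathered with the permutation, while index-valued attributes (`edge_index`, `graph_index`) are mapped through its inverse. The following is a minimal sketch on a toy graph; the helper `permute_nodes` is hypothetical and not part of any library.

```python
import torch


def permute_nodes(x, edge_index, perm):
    """Relabel the nodes of a single graph (hypothetical helper).

    perm[i] is the OLD index of the node placed at NEW position i, so node
    features are gathered as x[perm] (same convention as `data_p.x[p]`).
    Index-valued attributes store old node ids and must be mapped through
    the inverse permutation: old id j -> new position inv[j].
    """
    inv = torch.empty_like(perm)
    inv[perm] = torch.arange(perm.numel(), dtype=perm.dtype)
    return x[perm], inv[edge_index]


# Toy 4-node cycle with one scalar feature per node.
x = torch.arange(4, dtype=torch.float).unsqueeze(1)
edge_index = torch.tensor([[0, 1, 2, 3], [1, 2, 3, 0]])

# The cyclic shift from the example: the last node moves to the front.
p = torch.tensor([3, 0, 1, 2])
x_p, edge_index_p = permute_nodes(x, edge_index, p)
print(edge_index_p)  # same as (edge_index + 1) % 4

# With any random permutation, every edge keeps pointing at the same features.
perm = torch.randperm(4)
x_r, edge_index_r = permute_nodes(x, edge_index, perm)
assert torch.equal(x_r[edge_index_r], x[edge_index])
```

For the cyclic shift, the inverse-permutation rule reproduces exactly the "increase by one, wrap `n` to 0" remapping used in the example.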
Cause
The cause turned out to be in the `add_graph_token` function, in these lines:
```python
data.batch, sort_idx = torch.sort(data.batch)
data.x = data.x[sort_idx]
```
`torch.sort` is called to group each newly concatenated virtual node together with the other nodes of its graph. However, it is called without the `stable` argument, so the default `stable=False` is used. An unstable sort does not guarantee that elements with equal keys (here, nodes with the same batch index) keep their relative order, so the nodes inside each graph can end up permuted by the sorting algorithm. By itself this would not necessarily be a problem, since the model should be invariant to such permutations. However, all the other data attributes that are aligned with or index into the node order (`edge_index`, `in_degrees`, `att_bias`, etc.) still refer to the old order, and would then also have to be permuted or remapped.
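To keep the unstable sort, every attribute that indexes into the node tensor would have to be routed through the inverse of `sort_idx`. Below is a minimal sketch on plain tensors, not the library's code: the toy batch and all variable names are hypothetical, and it assumes the virtual-node entries are concatenated in front of the node features, which shifts every original node index by the number of graphs `B`.

```python
import torch

# Toy batch of two graphs: nodes 0-2 belong to graph 0, nodes 3-4 to graph 1.
batch = torch.tensor([0, 0, 0, 1, 1])
edge_index = torch.tensor([[0, 1, 3], [1, 2, 4]])
B = int(batch.max()) + 1

# Prepend one virtual-node entry per graph (assumed concatenation order).
batch_cat = torch.cat([torch.arange(B), batch])
# Original node ids now sit B positions further back.
edge_index_cat = edge_index + B

# Default (unstable) sort: the order of equal batch ids is not guaranteed.
batch_sorted, sort_idx = torch.sort(batch_cat)

# sort_idx[i] is the old position of the entry now at position i;
# its inverse new_pos[j] gives the new position of old entry j.
new_pos = torch.empty_like(sort_idx)
new_pos[sort_idx] = torch.arange(sort_idx.numel())

# Every index-valued attribute would need this remapping.
edge_index_sorted = new_pos[edge_index_cat]

# Each remapped endpoint still points at the same original entry.
assert torch.equal(sort_idx[edge_index_sorted], edge_index_cat)
```

This bookkeeping would have to be repeated for every index-valued attribute, which is why simply preserving the order is the more attractive option.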
Fix
Rather than remapping all of these attributes, the much simpler fix is to use a stable sort, i.e. to change the line to:
```python
data.batch, sort_idx = torch.sort(data.batch, stable=True)
```
After this change, the example above indeed produces identical outputs.
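For illustration, here is a stand-alone sketch of a token helper that uses the stable sort. It is modelled on the two quoted lines only; the function name `add_graph_token_stable`, its signature, and the way the token is replicated and prepended are assumptions, not the actual implementation. On a toy batch it shows that each graph comes out as its token followed by its nodes in their original order.

```python
import torch


def add_graph_token_stable(x, batch, token):
    """Hypothetical stand-alone version of the patched helper.

    Assumes one learned token of shape [1, dim] is prepended per graph,
    then everything is grouped by graph with a *stable* sort so that the
    original node order inside each graph is kept.
    """
    B = int(batch.max()) + 1
    tokens = token.repeat(B, 1)  # one copy of the token per graph
    x = torch.cat([tokens, x], dim=0)
    batch = torch.cat([torch.arange(B, device=batch.device), batch])
    batch, sort_idx = torch.sort(batch, stable=True)
    return x[sort_idx], batch, sort_idx


x = torch.arange(10, dtype=torch.float).view(5, 2)  # 5 nodes, 2 features
batch = torch.tensor([0, 0, 0, 1, 1])               # two graphs
token = torch.full((1, 2), -1.0)                    # stand-in for the learned token

x_out, batch_out, sort_idx = add_graph_token_stable(x, batch, token)
print(sort_idx.tolist())  # [0, 2, 3, 4, 1, 5, 6]

# Within each graph, old positions come out strictly increasing,
# i.e. the original node order is preserved.
for g in range(2):
    idx = sort_idx[batch_out == g]
    assert torch.all(idx[1:] > idx[:-1])
```

Because the relative order inside each graph is kept, the original node indices still line up with their nodes after the sort (up to the constant offset introduced by the prepended tokens).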
I haven't yet tested how this fix affects training and classification performance, but I could imagine that being invariant to node permutations, rather than having the node features "randomly" permuted relative to the other attributes, makes things a bit easier for the model.