While exploring the behavior of PyTorch Geometric's `LinkNeighborSampler`, we found that the sampling method consistently oversamples the first possible combination of neighbors and never samples the last possible combination. This error appears to be caused by the sampling logic implemented here.
How to reproduce
We tested with a toy example of a heterogeneous graph.
```python
from torch_geometric.data import HeteroData
from torch_geometric.transforms import ToUndirected
import torch

num_papers = 4
num_paper_features = 10
num_authors = 5
num_authors_features = 12

data = HeteroData()

# Create two node types "paper" and "author", each holding a feature matrix:
data['paper'].x = torch.randn(num_papers, num_paper_features)
data['paper'].id = torch.arange(0, num_papers)
data['paper'].type = "paper"
data['author'].x = torch.randn(num_authors, num_authors_features)
data['author'].id = torch.arange(0, num_authors)
data['author'].type = "author"

# Create an edge type "(author, writes, paper)" and build the
# graph connectivity (row 0: author index, row 1: paper index):
data['author', 'writes', 'paper'].edge_index = torch.tensor([
    [0, 1, 2, 2, 3, 4, 2, 2],
    [0, 0, 0, 1, 0, 1, 2, 3],
])

# Add the reverse edge type "(paper, rev_writes, author)":
transform = ToUndirected()
data = transform(data)
```
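For reference, the candidate set the sampler draws from can be read straight off the edge list. Paper 0, the paper in the seed edge used below, is written by authors 0, 1, 2 and 3, so drawing two of them without replacement has a fixed, small set of possible outcomes. A plain-Python check that needs no PyG (the variable names are illustrative only):

```python
from itertools import combinations

# The (author, writes, paper) edge list from the example above, as plain lists.
authors = [0, 1, 2, 2, 3, 4, 2, 2]  # row 0: source author
papers = [0, 0, 0, 1, 0, 1, 2, 3]   # row 1: destination paper

# Authors connected to paper 0, i.e. the candidates the sampler picks 2 from.
paper0_authors = sorted(a for a, p in zip(authors, papers) if p == 0)
pairs = list(combinations(paper0_authors, 2))

print(paper0_authors)  # [0, 1, 2, 3]
print(pairs)  # [(0, 1), (0, 2), (0, 3), (1, 2), (1, 3), (2, 3)]
```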
With this toy graph, we created a `LinkNeighborLoader` for the single seed edge (author 0, paper 0) that samples two authors of the seed paper:
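```python
from torch_geometric.loader import LinkNeighborLoader

# One hop: sample 2 authors per paper and 0 papers per author.
num_neighbors = {
    ('author', 'writes', 'paper'): [2],
    ('paper', 'rev_writes', 'author'): [0],
}
# A single seed edge: author 0 writes paper 0.
edge_label_index = (('author', 'writes', 'paper'), torch.tensor([[0], [0]]))
edge_label = torch.tensor([[1, 1]])

loader = LinkNeighborLoader(
    data,
    num_neighbors=num_neighbors,
    edge_label_index=edge_label_index,
    edge_label=edge_label,
)
```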
Lastly, to estimate how often each pair of neighbors is sampled, we drew 10,000 samples and counted the frequency of each pair:
```python
from collections import defaultdict

counts = defaultdict(lambda: 0)
num_samples = 10000
for _ in range(num_samples):
    sampled_data = next(iter(loader))
    # Source (author) indices of the sampled (author, writes, paper) edges,
    # local to this mini-batch:
    edge = sampled_data[('author', 'writes', 'paper')].edge_index[0].tolist()
    # Map the two local indices back to global author IDs:
    author_id = sampled_data['author'].n_id[edge[0]].item()
    author2_id = sampled_data['author'].n_id[edge[1]].item()
    key = tuple(sorted((author_id, author2_id)))
    counts[key] += 1
```
The output consistently showed the first combination being sampled twice as often as any other, while the last possible combination was never sampled:
```python
for k, v in dict(counts).items():
    print(k, v / num_samples)
# > (0, 1) 0.3382
# > (1, 2) 0.164
# > (0, 2) 0.1671
# > (0, 3) 0.1706
# > (1, 3) 0.1601
```
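In this case, the combination (0, 1) was oversampled, appearing twice as often as any other combination, while the combination (2, 3) was never sampled. Since paper 0 has four authors, there are six possible pairs, so a uniform sampler would draw each one about 1/6 ≈ 16.7% of the time.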
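The deviation can be quantified by dividing each reported frequency by the uniform probability 1/6. Pairs that were never drawn, such as (2, 3), have to be added explicitly, because `counts` only contains pairs that actually appeared. A minimal sketch using the frequencies reported above (`observed`, `expected` and `ratios` are illustrative names):

```python
from itertools import combinations

# Frequencies reported above for 10,000 draws; (2, 3) never appeared,
# so it is missing here, just as it is missing from `counts`.
observed = {(0, 1): 0.3382, (1, 2): 0.164, (0, 2): 0.1671,
            (0, 3): 0.1706, (1, 3): 0.1601}
expected = 1 / 6  # uniform probability over the 6 pairs of paper 0's authors

# Enumerate every possible pair so that unseen pairs show up with ratio 0.
ratios = {pair: observed.get(pair, 0.0) / expected
          for pair in combinations(range(4), 2)}
for pair, ratio in ratios.items():
    print(pair, f"{ratio:.2f}x uniform")
```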
This happened for every number of neighbors and root nodes we tried, as long as the number of sampled neighbors was smaller than the total number of neighbors.
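One possible explanation, offered as a hypothesis and not confirmed against the sampler's source: if sampling k of n neighbors without replacement is done with Robert Floyd's algorithm, drawing the random index from the half-open range [0, j) instead of the closed range [0, j] would produce exactly this pattern. The sketch below enumerates the exact output distribution of both variants. `floyd_distribution` and its `inclusive` flag are hypothetical names for illustration, not part of PyG or pyg-lib.

```python
from collections import defaultdict
from fractions import Fraction


def floyd_distribution(n, k, inclusive=True):
    """Exact distribution over sorted k-subsets of range(n) produced by
    Robert Floyd's sampling algorithm (hypothetical helper).

    inclusive=True draws rnd uniformly from [0, j] (textbook Floyd);
    inclusive=False draws from [0, j), the hypothesised off-by-one.
    """
    dist = defaultdict(Fraction)

    def step(j, chosen, prob):
        if j == n:  # all k picks have been made
            dist[tuple(sorted(chosen))] += prob
            return
        upper = j + 1 if inclusive else j  # number of equally likely rnd values
        for rnd in range(upper):
            # Floyd's rule: take rnd, or j itself if rnd was already chosen.
            pick = j if rnd in chosen else rnd
            step(j + 1, chosen | {pick}, prob / upper)

    step(n - k, frozenset(), Fraction(1))
    return dict(dist)


# Paper 0 has n = 4 authors and we sample k = 2 of them.
correct = floyd_distribution(4, 2, inclusive=True)
buggy = floyd_distribution(4, 2, inclusive=False)
print(sorted(set(correct.values())))  # [Fraction(1, 6)]
print(buggy[(0, 1)], buggy.get((2, 3), 0))  # 1/3 0
```

With the closed range, every pair of paper 0's authors has probability 1/6. With the half-open range, (0, 1) has probability 1/3, the other four observed pairs 1/6 each, and (2, 3) never occurs, which matches the measured frequencies. Under the same assumption, the half-open variant can never return the last combination (n - k, ..., n - 1) for any k < n, because its first draw always picks an index below n - k and nothing is ever removed afterwards.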