
Neighbor Sampling without replacement doesn't sample uniformly #385

Open
@aristizabal95

Description


While exploring the behavior of PyTorch Geometric's LinkNeighborSampler, we found that the sampling method consistently oversampled the first possible combination of neighbors and never sampled the last possible combination. The error appears to come from the sampling logic implemented here.

How to reproduce

We tested with a toy example of a heterogeneous graph.

```python
from torch_geometric.data import HeteroData
from torch_geometric.transforms import ToUndirected
import torch

num_papers = 4
num_paper_features = 10
num_authors = 5
num_authors_features = 12
data = HeteroData()

# Create two node types "paper" and "author", each holding a feature matrix:
data['paper'].x = torch.randn(num_papers, num_paper_features)
data['paper'].id = torch.arange(0, num_papers)
data['paper'].type = "paper"
data['author'].x = torch.randn(num_authors, num_authors_features)
data['author'].id = torch.arange(0, num_authors)
data['author'].type = "author"

# Create the edge type "(author, writes, paper)" and build the
# graph connectivity (row 0: author IDs, row 1: paper IDs).
# Paper 0 is written by authors 0, 1, 2 and 3.
data['author', 'writes', 'paper'].edge_index = torch.tensor([
    [0, 1, 2, 2, 3, 4, 2, 2],
    [0, 0, 0, 1, 0, 1, 2, 3],
])

# Add the reverse edge type ('paper', 'rev_writes', 'author'):
transform = ToUndirected()
data = transform(data)
```
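As a sanity check on what the sampler can choose from, the authors of paper 0 can be read directly off the edge index above. The sketch below uses plain Python lists copied from that tensor (no PyG needed) and enumerates every pair that a fanout of 2 without replacement could return:

```python
from itertools import combinations

# Rows copied from data['author', 'writes', 'paper'].edge_index above
authors = [0, 1, 2, 2, 3, 4, 2, 2]
papers = [0, 0, 0, 1, 0, 1, 2, 3]

# Authors that wrote paper 0 (the seed paper), i.e. its incoming neighbors
paper0_authors = [a for a, p in zip(authors, papers) if p == 0]
print(paper0_authors)  # [0, 1, 2, 3]

# All unordered pairs that sampling 2 of them without replacement can produce
pairs = list(combinations(paper0_authors, 2))
print(pairs)  # [(0, 1), (0, 2), (0, 3), (1, 2), (1, 3), (2, 3)]

# Probability of each pair under uniform sampling
print(1 / len(pairs))
```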

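With this toy graph, we created a LinkNeighborLoader seeded with the single edge (author 0, writes, paper 0). It samples two authors per paper along ('author', 'writes', 'paper') and no papers along the reverse edge type, so each batch contains two sampled (author, writes, paper 0) edges: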

```python
from torch_geometric.loader import LinkNeighborLoader

# Sample 2 author neighbors per paper and 0 paper neighbors per author:
num_neighbors = {
    ('author', 'writes', 'paper'): [2],
    ('paper', 'rev_writes', 'author'): [0],
}

# A single seed edge: (author 0, writes, paper 0)
edge_label_index = (('author', 'writes', 'paper'), torch.tensor([[0], [0]]))
edge_label = torch.tensor([1])  # one label per seed edge

loader = LinkNeighborLoader(
    data,
    num_neighbors=num_neighbors,
    edge_label_index=edge_label_index,
    edge_label=edge_label,
)
```

Lastly, to measure how often each pair of neighbors is sampled, we drew 10,000 batches and counted the frequency of each (sorted) pair of sampled authors:

```python
from collections import defaultdict

counts = defaultdict(int)
num_samples = 10000

for _ in range(num_samples):
    sampled_data = next(iter(loader))
    # Local indices of the source authors of the two sampled edges:
    src = sampled_data[('author', 'writes', 'paper')].edge_index[0].tolist()
    # Map the local indices back to global author IDs:
    author_id = sampled_data['author'].n_id[src[0]].item()
    author2_id = sampled_data['author'].n_id[src[1]].item()
    key = tuple(sorted((author_id, author2_id)))
    counts[key] += 1
```
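For comparison, the histogram a correct sampler should produce can be simulated without PyG. This sketch uses Python's `random.sample`, which draws uniformly without replacement, over paper 0's four authors (0, 1, 2, 3) with the same number of draws; every pair should land near 1/6 ≈ 0.167:

```python
import random
from collections import Counter

random.seed(0)  # fixed seed so the baseline is reproducible
num_samples = 10000
paper0_authors = [0, 1, 2, 3]

# Draw 2 of the 4 authors without replacement, uniformly at random,
# and count each sorted pair
baseline = Counter(
    tuple(sorted(random.sample(paper0_authors, 2))) for _ in range(num_samples)
)

for pair in sorted(baseline):
    print(pair, baseline[pair] / num_samples)
```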

The output consistently showed the first combination sampled twice as often as any other, while the last possible combination was never sampled:

```python
for k, v in dict(counts).items():
    print(k, v/num_samples)
# > (0, 1) 0.3382
# > (1, 2) 0.164
# > (0, 2) 0.1671
# > (0, 3) 0.1706
# > (1, 3) 0.1601
```

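In this case, the combination (0, 1) was oversampled, appearing about twice as often as any other combination (≈0.34 versus ≈0.17), while the combination (2, 3) was never sampled. Paper 0 has four authors (0, 1, 2, 3), so uniform sampling without replacement should return each of the C(4, 2) = 6 pairs with probability 1/6 ≈ 0.167.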
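To confirm that this gap is not sampling noise, the reported frequencies (converted back to counts out of 10,000) can be checked against the uniform expectation with a Pearson chi-square goodness-of-fit statistic. A minimal pure-Python sketch; the 5% critical value for 5 degrees of freedom (about 11.07) is taken from standard chi-square tables:

```python
# Counts reconstructed from the frequencies printed above (x 10,000);
# (2, 3) never appeared, so its count is 0.
observed = {
    (0, 1): 3382,
    (1, 2): 1640,
    (0, 2): 1671,
    (0, 3): 1706,
    (1, 3): 1601,
    (2, 3): 0,
}
total = sum(observed.values())        # 10000
expected = total / len(observed)      # uniform: 10000 / 6 per pair

# Pearson chi-square statistic: sum of (O - E)^2 / E over all six pairs
chi2 = sum((o - expected) ** 2 / expected for o in observed.values())

CRITICAL_5PCT_DF5 = 11.07  # chi-square critical value, alpha = 0.05, 5 dof
print(round(chi2, 1), chi2 > CRITICAL_5PCT_DF5)
```

The statistic lands in the thousands, far above the critical value, so the non-uniformity cannot be explained by chance.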

This happened with every number of sampled neighbors and root nodes we tried, as long as the number of sampled neighbors was smaller than the number of available neighbors.
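One hypothesis that reproduces these numbers exactly (an assumption on our part, not a confirmed diagnosis of the linked code) is an off-by-one in Robert Floyd's algorithm for sampling k of n items without replacement. The correct algorithm, for j = n-k, ..., n-1, draws t uniformly from [0, j] inclusive and inserts j whenever t is already chosen; drawing from [0, j) instead makes the last pair unreachable. The sketch below enumerates every possible sequence of draws for n = 4, k = 2 with exact fractions, assuming neighbor position i in paper 0's adjacency list corresponds to author i. `floyd_distribution` and its `exclusive_upper` flag are hypothetical names used only for this illustration:

```python
from fractions import Fraction
from itertools import product


def floyd_distribution(n, k, exclusive_upper=False):
    """Exact distribution of the subsets returned by Floyd's algorithm.

    With exclusive_upper=True, each draw comes from [0, j) instead of
    [0, j], simulating a hypothetical off-by-one bug.
    """
    assert 0 < k < n, "sketch assumes 0 < k < n"
    steps = range(n - k, n)
    # Candidate values for the draw at each step j
    choices = [range(j if exclusive_upper else j + 1) for j in steps]
    prob = Fraction(1)
    for c in choices:
        prob /= len(c)  # every full sequence of draws is equally likely
    dist = {}
    for draws in product(*choices):
        chosen = set()
        for j, t in zip(steps, draws):
            # Floyd's rule: if t was already picked, take j instead
            chosen.add(j if t in chosen else t)
        key = tuple(sorted(chosen))
        dist[key] = dist.get(key, Fraction(0)) + prob
    return dist


correct = floyd_distribution(4, 2)
buggy = floyd_distribution(4, 2, exclusive_upper=True)
for pair in sorted(correct):
    print(pair, correct[pair], buggy.get(pair, Fraction(0)))
```

Running it prints 1/6 for every pair under the correct rule, while the off-by-one variant gives (0, 1) a probability of 1/3, (2, 3) a probability of 0, and 1/6 to each of the other four pairs, which matches the ≈0.34 / 0 / ≈0.17 pattern reported above.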
