
Webpage examples do not use temporal encodings #307

@pauvilasoler


Hello!

I guess this is not a super important issue, but I thought it was worth a heads-up: it would be nicer if the examples actually used temporal encodings, since I'm guessing that is what most people come to this library for.

Basically, the examples on the webpage use a recurrent architecture but never pass the hidden state from one snapshot to the next (see below). In other words, each prediction is made independently from the features of a single snapshot, which does not need PyG Temporal at all and could easily be done with PyG alone.

[Two screenshots of the webpage examples; images not preserved]
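To make the difference concrete, here is a minimal stdlib-only sketch. It does not use PyG Temporal at all: `toy_cell` is a hypothetical scalar stand-in for a recurrent layer such as `GConvGRU`, picked only so the effect of the hidden state is easy to see. As far as I can tell, `GConvGRU` also starts from a zero state when `H` is `None`, so resetting the state at every snapshot means the output depends only on the current input, while carrying it forward lets the same input produce different outputs depending on history.

```python
import math
from typing import List, Optional


def toy_cell(x: float, h: Optional[float]) -> float:
    """Hypothetical scalar recurrent cell: the new state mixes input and old state."""
    if h is None:
        # Start from a zero state when no state is given
        # (assumed to mirror GConvGRU's H=None behaviour).
        h = 0.0
    return math.tanh(0.5 * x + 0.9 * h)


def run(xs: List[float], carry_state: bool) -> List[float]:
    """Apply the cell over a sequence, optionally feeding the state back in."""
    outputs = []
    h = None
    for x in xs:
        # Stateless: pass None every time. Stateful: pass the previous state.
        h = toy_cell(x, h if carry_state else None)
        outputs.append(h)
    return outputs


xs = [1.0, 1.0, 1.0]
stateless = run(xs, carry_state=False)  # what the webpage examples effectively do
stateful = run(xs, carry_state=True)    # what passing H between snapshots does
print(stateless)
print(stateful)
```

With identical inputs, the stateless run returns the same value three times, whereas the stateful run drifts as the state accumulates.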

A quick workaround could be:

```python
import torch
import torch.nn.functional as F
from torch_geometric_temporal.dataset import WikiMathsDatasetLoader
from torch_geometric_temporal.nn.recurrent import GConvGRU
from torch_geometric_temporal.signal import temporal_signal_split

torch.manual_seed(1)

# WikiMaths with 14 lagged daily values as node features, split chronologically:
# the first half of the snapshots for training, the second half for testing.
loader = WikiMathsDatasetLoader()
dataset = loader.get_dataset(lags=14)
train_dataset, test_dataset = temporal_signal_split(dataset, train_ratio=0.5)


class RecurrentGCN(torch.nn.Module):
    def __init__(self, node_features, filters):
        super().__init__()
        # Graph-convolutional GRU with Chebyshev filter size K=2.
        self.recurrent = GConvGRU(node_features, filters, 2)
        self.linear = torch.nn.Linear(filters, 1)

    def forward(self, x, edge_index, edge_weight, h):
        # Feed the previous hidden state in instead of starting from None.
        h = self.recurrent(x, edge_index, edge_weight, H=h)
        y_hat = self.linear(F.relu(h))
        # Return the prediction and the new state so the caller can carry it on.
        return y_hat, h


model = RecurrentGCN(node_features=14, filters=32)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

model.train()
for epoch in range(50):
    h = None  # Reset the hidden state at the start of every pass over the data.
    for snapshot in train_dataset:
        y_hat, h = model(snapshot.x, snapshot.edge_index, snapshot.edge_attr, h)
        cost = torch.mean((y_hat.squeeze() - snapshot.y) ** 2)
        cost.backward()
        optimizer.step()
        optimizer.zero_grad()
        # Keep the value of h but cut its autograd history, so each backward
        # pass only covers the current snapshot.
        h = h.detach()

model.eval()
cost = 0
with torch.no_grad():
    # h still holds the state after the last training snapshot, and the
    # test snapshots follow it in time, so we continue from there.
    for time, snapshot in enumerate(test_dataset):
        y_hat, h = model(snapshot.x, snapshot.edge_index, snapshot.edge_attr, h)
        cost = cost + torch.mean((y_hat.squeeze() - snapshot.y) ** 2)
    cost = cost / (time + 1)
cost = cost.item()
print("MSE: {:.4f}".format(cost))
```
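The test loop starts from the hidden state left over from training. That only makes sense because the split is chronological: as far as I can tell, `temporal_signal_split` keeps the snapshot order and puts the first `int(train_ratio * snapshot_count)` snapshots in the training set, so the first test snapshot directly follows the last training one. Here is a minimal stdlib sketch of that assumed behaviour on a plain list; `chronological_split` is a hypothetical helper, not the library function.

```python
from typing import List, Sequence, Tuple, TypeVar

T = TypeVar("T")


def chronological_split(snapshots: Sequence[T], train_ratio: float) -> Tuple[List[T], List[T]]:
    """Hypothetical helper: split a time-ordered sequence without shuffling."""
    cut = int(train_ratio * len(snapshots))  # number of training snapshots
    return list(snapshots[:cut]), list(snapshots[cut:])


# Stand-in for the snapshots: just their time indices.
days = list(range(10))
train, test = chronological_split(days, train_ratio=0.5)
print(train, test)  # [0, 1, 2, 3, 4] [5, 6, 7, 8, 9]
```

If the split were shuffled instead, carrying `h` into the test loop would feed the model a state from an unrelated point in time.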

The MSE goes up a tiny bit when the temporal information is used, but I feel it would be a more faithful example of what can be done with this library.
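For reference, the printed MSE is the average of per-snapshot MSEs (the `cost / (time+1)` step in the script). When every snapshot has the same number of nodes, as with a static-graph dataset like WikiMaths, this equals the MSE pooled over all (snapshot, node) pairs. A small stdlib sketch with toy numbers (not WikiMaths data; the helper names are hypothetical):

```python
from typing import List


def mse(pred: List[float], target: List[float]) -> float:
    """Mean squared error over the nodes of one snapshot."""
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(target)


def mean_snapshot_mse(preds: List[List[float]], targets: List[List[float]]) -> float:
    """Average of per-snapshot MSEs, like dividing the summed cost by the snapshot count."""
    return sum(mse(p, t) for p, t in zip(preds, targets)) / len(targets)


def pooled_mse(preds: List[List[float]], targets: List[List[float]]) -> float:
    """MSE over all (snapshot, node) pairs pooled together."""
    flat_p = [v for p in preds for v in p]
    flat_t = [v for t in targets for v in t]
    return mse(flat_p, flat_t)


# Two snapshots with three nodes each: per-snapshot MSEs are 1/3 and 1.
preds = [[1.0, 2.0, 3.0], [0.0, 0.0, 0.0]]
targets = [[1.0, 2.0, 4.0], [1.0, 1.0, 1.0]]
print(mean_snapshot_mse(preds, targets), pooled_mse(preds, targets))
```

Both come out to 2/3 here. With a varying number of nodes per snapshot the two numbers would differ, so it matters which one is being compared.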

Thanks!
