Hello!
I guess this is not a super important issue, but I thought it was worth a heads up: it would be nicer to show examples that actually use the temporal encoding, since I'm guessing that is what most people are after, given the purpose of the library.
Basically, the examples on the webpage use a recurrent architecture without passing the hidden state from one snapshot to the next (see below). In other words, they just predict each snapshot in sequence using only the information in that snapshot, and you don't need PyG Temporal for that: it could easily be done with PyG alone.
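To illustrate the point above without any library code, here is a toy sketch. The scalar cell `rnn_step` and the weights `w` and `u` are hypothetical stand-ins for a recurrent layer, not GConvGRU itself. As far as I can tell, GConvGRU falls back to a zero hidden state when `H` is not passed, so calling it without `H` behaves like the `carry_state=False` branch: every step starts from zero, and the output depends only on the current input.

```python
import math


def rnn_step(x, h, w=0.8, u=0.5):
    """Hypothetical toy recurrent cell (not GConvGRU): h_t = tanh(w * x_t + u * h_{t-1})."""
    return math.tanh(w * x + u * h)


def run(xs, carry_state):
    """Run the toy cell over a sequence, optionally carrying the hidden state forward."""
    h = 0.0
    outputs = []
    for x in xs:
        # When carry_state is False, every step restarts from a zero state,
        # which is what happens when the layer is called without H.
        h = rnn_step(x, h if carry_state else 0.0)
        outputs.append(h)
    return outputs


xs = [1.0, 1.0, 1.0]
stateless = run(xs, carry_state=False)
stateful = run(xs, carry_state=True)
print(stateless)
print(stateful)
```

With `carry_state=False` all three outputs are identical (each is `tanh(0.8)`), because the same input always produces the same output. With `carry_state=True` the outputs keep changing even though the input does not, because earlier steps now feed into later ones.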
A quick workaround could be:
```python
import torch
import torch.nn.functional as F
from tqdm import tqdm
from torch_geometric_temporal.dataset import WikiMathsDatasetLoader
from torch_geometric_temporal.nn.recurrent import GConvGRU
from torch_geometric_temporal.signal import temporal_signal_split

torch.manual_seed(1)

loader = WikiMathsDatasetLoader()
dataset = loader.get_dataset(lags=14)
train_dataset, test_dataset = temporal_signal_split(dataset, train_ratio=0.5)


class RecurrentGCN(torch.nn.Module):
    def __init__(self, node_features, filters):
        super(RecurrentGCN, self).__init__()
        self.recurrent = GConvGRU(node_features, filters, 2)
        self.linear = torch.nn.Linear(filters, 1)

    def forward(self, x, edge_index, edge_weight, h):
        # Feed the previous hidden state in, so information flows across snapshots
        h = self.recurrent(x, edge_index, edge_weight, H=h)
        y_hat = self.linear(F.relu(h))
        # Also return the new hidden state so the caller can pass it to the next snapshot
        return y_hat, h


model = RecurrentGCN(node_features=14, filters=32)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

model.train()
for epoch in tqdm(range(50)):
    h = None  # Reset the hidden state at the start of each pass over the sequence
    for time, snapshot in enumerate(train_dataset):
        y_hat, h = model(snapshot.x, snapshot.edge_index, snapshot.edge_attr, h)
        cost = torch.mean((y_hat.squeeze() - snapshot.y) ** 2)
        cost.backward()
        optimizer.step()
        optimizer.zero_grad()
        # Keep the state's value but cut the graph, so each backward pass covers one step
        h = h.detach()

model.eval()
cost = 0
with torch.no_grad():
    # h still holds the state after the last training snapshot; the test split follows it in time
    for time, snapshot in enumerate(test_dataset):
        y_hat, h = model(snapshot.x, snapshot.edge_index, snapshot.edge_attr, h)
        cost = cost + torch.mean((y_hat.squeeze() - snapshot.y) ** 2)
    cost = cost / (time + 1)
    cost = cost.item()
print("MSE: {:.4f}".format(cost))
```
The test MSE goes up a tiny bit with the temporal information, but I feel this would be a more faithful example of what can be done with this library.
Thanks!