Skip to content

Commit 268b01d

Browse files
akashshah59klane
authored andcommitted
Correcting example to newer version
1 parent 717afc1 commit 268b01d

File tree

1 file changed

+20
-27
lines changed

1 file changed

+20
-27
lines changed

examples/dcrnn/main.py

+20-27
Original file line numberDiff line numberDiff line change
@@ -1,57 +1,50 @@
11
import time
2+
import os
3+
import numpy as np
24

3-
from pytorch_lightning import Trainer
5+
import torch.optim
46

57
from torchts.nn.models.dcrnn import DCRNN
68
from torchts.utils import data as utils
79

8-
data_config = {
9-
"batch_size": 8,
10-
"dataset_dir": ".", # Absolute path to train , test, val expected
11-
"test_batch_size": 8,
12-
"val_batch_size": 8,
13-
"graph_pkl_filename": "adjacency_matrix.pkl",
14-
# Absolute path to graph file expected
15-
}
16-
1710
model_config = {
1811
"cl_decay_steps": 2000,
1912
"filter_type": "dual_random_walk",
2013
"horizon": 12,
14+
"seq_len":12,
2115
"input_dim": 2,
22-
"l1_decay": 0,
23-
"max_diffusion_step": 2,
24-
"num_nodes": 320,
25-
"num_rnn_layers": 2,
26-
"output_dim": 1,
27-
"rnn_units": 128,
28-
"seq_len": 12,
29-
"use_curriculum_learning": "true",
16+
"max_diffusion_step": 2,
17+
"num_layers": 2,
18+
"output_dim": 2,
19+
"use_curriculum_learning": True,
3020
}
3121

22+
optimizer_args = {
23+
'lr':0.01
24+
}
25+
3226
# Code to retrieve the graph in the form of an adjacency matrix.
3327
# This corresponds to the distance between 2 traffic sensors in a traffic network.
3428
# For other applications it can mean anything that defines the adjacency between nodes
3529
# eg. distance between airports of different cities when predicting
3630
# covid infection rate.
3731

38-
graph_pkl_filename = data_config["graph_pkl_filename"]
39-
sensor_ids, sensor_id_to_ind, adj_mx = utils.load_graph_data(graph_pkl_filename)
32+
graph_pkl_filename = "/home/akash/Desktop/multi-gpu/newest_adj_max.pkl" # Absolute path of graph expected.
4033

41-
data = utils.load_dataset(**data_config)
42-
scaler = data["scaler"]
34+
_,_, adj_mx = utils.load_graph_data(graph_pkl_filename)
4335

44-
model = DCRNN(adj_mx, scaler, **model_config)
36+
num_units = adj_mx.shape[0]
4537

38+
model_config['num_nodes'] = num_units
4639

47-
def run():
48-
trainer = Trainer(max_epochs=10, logger=True)
40+
data = np.load("/home/akash/Desktop/multi-gpu/train.npz") # Absolute path of train, test, val needed.
4941

42+
def run():
43+
model = DCRNN(adj_mx,num_units,optimizer = torch.optim.SGD,optimizer_args = optimizer_args,**model_config)
5044
start = time.time()
51-
trainer.fit(model, data["train_loader"], data["val_loader"])
45+
model.fit(torch.from_numpy(data['x'].astype('float32')),torch.from_numpy(data['y'].astype('float32')),max_epochs = 10,batch_size = 8)
5246
end = time.time() - start
5347
print("Training time taken %f" % (end - start))
5448

55-
5649
if __name__ == "__main__":
5750
run()

0 commit comments

Comments
 (0)