|
1 | 1 | import time
|
| 2 | +import os |
| 3 | +import numpy as np |
2 | 4 |
|
3 |
| -from pytorch_lightning import Trainer |
| 5 | +import torch.optim |
4 | 6 |
|
5 | 7 | from torchts.nn.models.dcrnn import DCRNN
|
6 | 8 | from torchts.utils import data as utils
|
7 | 9 |
|
8 |
| -data_config = { |
9 |
| - "batch_size": 8, |
10 |
| - "dataset_dir": ".", # Absolute path to train , test, val expected |
11 |
| - "test_batch_size": 8, |
12 |
| - "val_batch_size": 8, |
13 |
| - "graph_pkl_filename": "adjacency_matrix.pkl", |
14 |
| - # Absolute path to graph file expected |
15 |
| -} |
16 |
| - |
17 | 10 | model_config = {
|
18 | 11 | "cl_decay_steps": 2000,
|
19 | 12 | "filter_type": "dual_random_walk",
|
20 | 13 | "horizon": 12,
|
| 14 | + "seq_len":12, |
21 | 15 | "input_dim": 2,
|
22 |
| - "l1_decay": 0, |
23 |
| - "max_diffusion_step": 2, |
24 |
| - "num_nodes": 320, |
25 |
| - "num_rnn_layers": 2, |
26 |
| - "output_dim": 1, |
27 |
| - "rnn_units": 128, |
28 |
| - "seq_len": 12, |
29 |
| - "use_curriculum_learning": "true", |
| 16 | + "max_diffusion_step": 2, |
| 17 | + "num_layers": 2, |
| 18 | + "output_dim": 2, |
| 19 | + "use_curriculum_learning": True, |
30 | 20 | }
|
31 | 21 |
|
| 22 | +optimizer_args = { |
| 23 | + 'lr':0.01 |
| 24 | + } |
| 25 | + |
32 | 26 | # Code to retrieve the graph in the form of an adjacency matrix.
|
33 | 27 | # This corresponds to the distance between 2 traffic sensors in a traffic network.
|
34 | 28 | # For other applications it can mean anything that defines the adjacency between nodes
|
35 | 29 | # eg. distance between airports of different cities when predicting
|
36 | 30 | # covid infection rate.
|
37 | 31 |
|
38 |
| -graph_pkl_filename = data_config["graph_pkl_filename"] |
39 |
| -sensor_ids, sensor_id_to_ind, adj_mx = utils.load_graph_data(graph_pkl_filename) |
| 32 | +graph_pkl_filename = "/home/akash/Desktop/multi-gpu/newest_adj_max.pkl" # Absolute path of graph expected. |
40 | 33 |
|
41 |
| -data = utils.load_dataset(**data_config) |
42 |
| -scaler = data["scaler"] |
| 34 | +_,_, adj_mx = utils.load_graph_data(graph_pkl_filename) |
43 | 35 |
|
44 |
| -model = DCRNN(adj_mx, scaler, **model_config) |
| 36 | +num_units = adj_mx.shape[0] |
45 | 37 |
|
| 38 | +model_config['num_nodes'] = num_units |
46 | 39 |
|
47 |
| -def run(): |
48 |
| - trainer = Trainer(max_epochs=10, logger=True) |
| 40 | +data = np.load("/home/akash/Desktop/multi-gpu/train.npz") # Absolute path of train, test, val needed. |
49 | 41 |
|
| 42 | +def run(): |
| 43 | + model = DCRNN(adj_mx,num_units,optimizer = torch.optim.SGD,optimizer_args = optimizer_args,**model_config) |
50 | 44 | start = time.time()
|
51 |
| - trainer.fit(model, data["train_loader"], data["val_loader"]) |
| 45 | + model.fit(torch.from_numpy(data['x'].astype('float32')),torch.from_numpy(data['y'].astype('float32')),max_epochs = 10,batch_size = 8) |
52 | 46 | end = time.time() - start
|
53 | 47 | print("Training time taken %f" % (end - start))
|
54 | 48 |
|
55 |
| - |
56 | 49 | if __name__ == "__main__":
|
57 | 50 | run()
|
0 commit comments