Skip to content

Commit 6071118

Browse files
akashshah59klane
authored andcommitted
Removing absolute file path
1 parent 268b01d commit 6071118

File tree

1 file changed

+25
-13
lines changed

1 file changed

+25
-13
lines changed

examples/dcrnn/main.py

+25-13
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import time
2-
import os
3-
import numpy as np
42

3+
import numpy as np
54
import torch.optim
65

76
from torchts.nn.models.dcrnn import DCRNN
@@ -11,40 +10,53 @@
1110
"cl_decay_steps": 2000,
1211
"filter_type": "dual_random_walk",
1312
"horizon": 12,
14-
"seq_len":12,
13+
"seq_len": 12,
1514
"input_dim": 2,
16-
"max_diffusion_step": 2,
15+
"max_diffusion_step": 2,
1716
"num_layers": 2,
1817
"output_dim": 2,
1918
"use_curriculum_learning": True,
2019
}
2120

22-
optimizer_args = {
23-
'lr':0.01
24-
}
21+
optimizer_args = {"lr": 0.01}
2522

2623
# Code to retrieve the graph in the form of an adjacency matrix.
2724
# This corresponds to the distance between 2 traffic sensors in a traffic network.
2825
# For other applications it can mean anything that defines the adjacency between nodes
2926
# eg. distance between airports of different cities when predicting
3027
# covid infection rate.
3128

32-
graph_pkl_filename = "/home/akash/Desktop/multi-gpu/newest_adj_max.pkl" # Absolute path of graph expected.
29+
graph_pkl_filename = "<Path to graph>"
3330

34-
_,_, adj_mx = utils.load_graph_data(graph_pkl_filename)
31+
_, _, adj_mx = utils.load_graph_data(graph_pkl_filename)
3532

3633
num_units = adj_mx.shape[0]
3734

38-
model_config['num_nodes'] = num_units
35+
model_config["num_nodes"] = num_units
36+
37+
data = np.load(
38+
"<Path to training *.npz file>"
39+
) # Absolute path of train, test, val needed.
3940

40-
data = np.load("/home/akash/Desktop/multi-gpu/train.npz") # Absolute path of train, test, val needed.
4141

4242
def run():
43-
model = DCRNN(adj_mx,num_units,optimizer = torch.optim.SGD,optimizer_args = optimizer_args,**model_config)
43+
model = DCRNN(
44+
adj_mx,
45+
num_units,
46+
optimizer=torch.optim.SGD,
47+
optimizer_args=optimizer_args,
48+
**model_config
49+
)
4450
start = time.time()
45-
model.fit(torch.from_numpy(data['x'].astype('float32')),torch.from_numpy(data['y'].astype('float32')),max_epochs = 10,batch_size = 8)
51+
model.fit(
52+
torch.from_numpy(data["x"].astype("float32")),
53+
torch.from_numpy(data["y"].astype("float32")),
54+
max_epochs=10,
55+
batch_size=8,
56+
)
4657
end = time.time() - start
4758
print("Training time taken %f" % (end - start))
4859

60+
4961
if __name__ == "__main__":
5062
run()

0 commit comments

Comments
 (0)