forked from dongspam0209/traffic_light_control_DQN
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtraining_main.py
More file actions
110 lines (91 loc) · 3.04 KB
/
training_main.py
File metadata and controls
110 lines (91 loc) · 3.04 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
from generator import CarGenerator
from util import set_sumo , set_train_path
from simulation import Simulation
import matplotlib.pyplot as plt
from Model import DQN
import traci
from replay import ReplayMemory
import wandb
from visualization import Visualization
################################################################
# Training hyperparameters (module-level so they are importable without
# side effects; nothing is started until main() runs).
total_episode = 500        # number of training episodes; Simulation.run is called once per episode
n_cars_generated = 1000    # passed to CarGenerator (presumably vehicles generated per episode)
num_states = (3, 16, 100)  # state shape handed to Simulation; layout is defined there — TODO document
num_actions = 4            # number of discrete actions the agent chooses between
yellow_duration = 3        # phase durations, presumably in simulation steps — TODO confirm units
green_duration = 8
green_turn_duration = 4
memory_capacity = 1000     # capacity passed to ReplayMemory
max_steps = 3600           # simulation length passed to set_sumo / CarGenerator / Simulation
################################################################


def main():
    """Train the DQN traffic-light agent and save per-episode metric plots.

    Starts a Weights & Biases run that records this script's hyperparameters,
    builds the replay memory, car generator, simulation and plotter, then runs
    ``total_episode`` episodes with a linearly decaying epsilon.  After each
    episode the queue length, loss, wait time and reward reported by the
    simulation are printed and logged to wandb.  At the end, the per-episode
    series are saved/plotted under the ``plot`` training path and the wandb
    run is closed.
    """
    sumocfg_file_name = "cross.sumocfg"
    gui = False  # set to True to watch the run in the SUMO GUI

    wandb.init(
        # wandb project where this run is logged
        project="tlsc-project",
        # Record the hyperparameters this script actually uses.
        config={
            "architecture": "DQN",
            # NOTE(review): not consumed by this script; presumably mirrors the
            # optimizer learning rate set in Model/Simulation — confirm.
            "learning_rate": 0.0001,
            "sumocfg": sumocfg_file_name,
            "total_episode": total_episode,
            "n_cars_generated": n_cars_generated,
            "num_states": num_states,
            "num_actions": num_actions,
            "yellow_duration": yellow_duration,
            "green_duration": green_duration,
            "green_turn_duration": green_turn_duration,
            "memory_capacity": memory_capacity,
            "max_steps": max_steps,
        },
    )

    sumo_cmd = set_sumo(gui, sumocfg_file_name, max_steps)
    path = set_train_path('plot')

    # Instances get snake_case names so they never shadow the imported classes.
    replay_memory = ReplayMemory(memory_capacity)
    car_generator = CarGenerator(max_steps, n_cars_generated)
    # The DQN *class* (not an instance) is handed to Simulation together with
    # num_states/num_actions; presumably Simulation builds its own network(s)
    # from it — TODO confirm. No separate DQN instance is needed in this script.
    simulation = Simulation(
        DQN,
        replay_memory,
        car_generator,
        sumo_cmd,
        max_steps,
        num_states,
        num_actions,
        green_duration,
        yellow_duration,
        green_turn_duration,
    )
    visualization = Visualization(path, dpi=96)

    for episode in range(total_episode):
        print(f'episode {episode}')
        # Linear epsilon decay: 1.0 on the first episode down to
        # 1/total_episode on the last one.
        epsilon = 1 - (episode / total_episode)
        simulation.run(episode, epsilon)

        # Per-episode statistics read back from the simulation's *_store
        # sequences, which are indexed by episode. The keys double as the
        # printed labels and the wandb metric names.
        stats = {
            "queue length": simulation.queue_length_store[episode],
            "loss": simulation.loss_store[episode],
            "wait time": simulation.wait_time_store[episode],
            "reward": simulation.reward_store[episode],
        }
        for name, value in stats.items():
            print(f'{name} in episode {episode}', value)
        wandb.log({"episode": episode, "epsilon": epsilon, **stats})

    visualization.save_data_and_plot(
        data=simulation.queue_length_store,
        filename='queue', xlabel='Episode', ylabel='queue length',
    )
    visualization.save_data_and_plot(
        data=simulation.loss_store,
        filename='loss', xlabel='Episode', ylabel='loss',
    )
    visualization.save_data_and_plot(
        data=simulation.wait_time_store,
        filename='wait_time_episode', xlabel='Episode', ylabel='wait_time',
    )
    # Reward is the agent's training signal and is already logged per episode;
    # persist its curve alongside the other metrics.
    visualization.save_data_and_plot(
        data=simulation.reward_store,
        filename='reward', xlabel='Episode', ylabel='reward',
    )

    # Close the wandb run explicitly so all logged data is flushed.
    wandb.finish()


if __name__ == "__main__":
    main()