-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain.py
More file actions
75 lines (67 loc) · 2.33 KB
/
train.py
File metadata and controls
75 lines (67 loc) · 2.33 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
from agent import DiscreteEpsilonGreedyAgent
from matplotlib import pyplot as plt
from tetris_gymnasium.envs import Tetris
from torch.utils.tensorboard import SummaryWriter
from typing import Callable
import gymnasium as gym
import numpy as np
import secrets
import sys
import time
import Utils
# Run-scale / debugging knobs for this training script (distinct from the
# learning hyperparameters below).
debugParameters = {
    # "renderMode": "human",
    "renderMode": None,  # None = headless; "human" renders the env for visual debugging
    "numEpisodes": 100_000,  # upper bound on episodes per training run
    "numTotalSteps": 10_000_000, # > epsilonDecaySteps + learningStartPoint
    "plotRollingLength": 1000,  # presumably the rolling-average window for the result plots — confirm in Utils.plotBatchResults
}
# DQN-style learning hyperparameters, passed verbatim to
# DiscreteEpsilonGreedyAgent in main().
hyperParameters = {
    "epsilonStart": 1,  # initial exploration rate
    "epsilonEnd": 0.1,  # floor for the exploration rate
    "epsilonDecaySteps": 4_000_000,  # steps over which epsilon anneals from start to end
    "learningRate": 0.001,
    "discountFactor": 0.99,  # gamma for future-reward discounting
    "replayBufferCapacity": 1_000_000,  # max transitions retained for replay
    "batchTransitionSampleSize": 512,  # minibatch size per training update
    "trainingFrequency": 4,  # NOTE(review): likely "train every N env steps" — confirm in agent
    "targetNetworkUpdateFrequency": 10_000,  # steps between target-network syncs, presumably
    "checkpointRate": 100_000,  # steps between model checkpoints, presumably
    "learningStartPoint": 1_000_000,  # steps of pure experience collection before updates begin, presumably
}
def main():
    """Train a DiscreteEpsilonGreedyAgent on Tetris and plot the results.

    An optional command-line argument names a model checkpoint to resume from.
    """
    # numpy seeds must lie in [0, 2**32 - 1]; secrets gives an unbiased draw.
    seed = secrets.randbits(32)

    # Optional first CLI argument: path to a saved model to load.
    modelPath = sys.argv[1] if len(sys.argv) > 1 else None

    writer = SummaryWriter(comment='_train', purge_step=10_000, max_queue=1_000)
    #writer.add_hparams(hyperParameters)

    env: Tetris = gym.make(
        "tetris_gymnasium/Tetris",
        render_mode=debugParameters["renderMode"])
    env.reset(seed=seed)

    # Uniformly sample one of the environment's legal actions.
    def sampleRandomAction() -> int:
        return env.action_space.sample()

    agent = DiscreteEpsilonGreedyAgent(
        seed=seed,
        numActions=env.action_space.n,
        randomActionFn=sampleRandomAction,
        writer=writer,
        train=True,
        modelPath=modelPath,
        hyperParameters=hyperParameters)

    totalStepsList, totalRewardList = Utils.runBatchEpisodes(
        env,
        agent,
        writer,
        debugParameters["numEpisodes"],
        debugParameters["numTotalSteps"],
        train=True)

    # Release resources before the (blocking) plot call.
    agent.QFunction.close()
    env.close()
    writer.close()

    # Visualization
    pairs = [
        ("Episode lengths", totalStepsList),
        ("Episode rewards", totalRewardList),
    ]
    Utils.plotBatchResults(pairs, debugParameters["plotRollingLength"])


if __name__ == "__main__":
    main()