-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathddpg.py
More file actions
38 lines (29 loc) · 1.19 KB
/
ddpg.py
File metadata and controls
38 lines (29 loc) · 1.19 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
#%% train ddpg
import os
import gym
import numpy as np
from stable_baselines import DDPG
from stable_baselines.ddpg.policies import MlpPolicy
from stable_baselines.ddpg.policies import LnMlpPolicy
from stable_baselines.common.noise import NormalActionNoise
from stable_baselines.bench import Monitor
from stable_baselines.common.callbacks import BaseCallback
import Arm2DEnv as ae
from utils import SaveOnBestTrainingRewardCallback, snap_code
#%% sac
log_dir_root = './sandbox/ddpg/'
os.makedirs(log_dir_root, exist_ok=True)
# create a snapshot of code
log_dir = snap_code(log_dir_root)
env = ae.ArmModel()
env.mode = 'train'
env = Monitor(env, log_dir)
n_actions = env.action_space.shape[-1]
action_noise = NormalActionNoise(mean=np.zeros(n_actions), sigma=float(0.1) * np.ones(n_actions))
model = DDPG(LnMlpPolicy, env, verbose=1, gamma=0.98, tau=0.01,
actor_lr=0.0001, critic_lr=0.001, action_noise=action_noise,
buffer_size=int(5E6), batch_size=128, random_exploration=0.001)
callback = SaveOnBestTrainingRewardCallback(check_freq=int(5E4), log_dir=log_dir)
model.learn(total_timesteps=int(6E6), callback=callback)#(total_timesteps=int(4E5))
# %%
model.save("twolink_arm_ddpg")