test_sac.py
import argparse
import random

import gymnasium as gym
import numpy as np
import torch
from matplotlib import pyplot as plt

from policies.actor.continuous_actor import ContinuousSoftActor
from policies.network import get_MLP
from policies.sac import SAC
from utils.replay_buffer import ReplayBuffer
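
# CLI seeds: --seed drives the RNGs (torch/numpy/random), --env_seed the environment's initial reset.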
parser = argparse.ArgumentParser()
parser.add_argument("--env_seed", type=int, default=42)
parser.add_argument("--seed", type=int, default=42)
args = parser.parse_args()
seed = args.seed
env_seed = args.env_seed
print("seed", seed, "; env seed", env_seed)
env = gym.make('Pendulum-v1', max_episode_steps=200)
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)
env.reset(seed=env_seed)
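
# Actor with a bounded continuous action head; act_bias=0 and act_scale=2 presumably
# rescale a squashed (tanh) output onto Pendulum's torque range [-2, 2].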
actor_module = ContinuousSoftActor(
    state_dim=3,
    action_dim=1,
    act_bias=np.array([0]),
    act_scale=np.array([2]),
)
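
# Twin Q-critics (SAC's clipped double-Q): each maps the concatenated
# (state, action) pair (3 + 1 input features) to a scalar value.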
critic_module = get_MLP(
    num_features=3 + 1,
    num_actions=1,
    hidden_layers=[128],
)
critic2_module = get_MLP(
    num_features=3 + 1,
    num_actions=1,
    hidden_layers=[128],
)
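
# auto_alpha=True presumably enables automatic tuning of the entropy temperature,
# as in the standard SAC formulation; lr_alpha would be its learning rate.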
sac_policy = SAC(
    q1=critic_module,
    q2=critic2_module,
    pi=actor_module,
    state_dim=3,
    action_dim=1,
    lr_q=1e-2,
    lr_pi=1e-3,
    lr_alpha=1e-2,
    auto_alpha=True,
)
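
# Experience replay: store transitions and sample uncorrelated mini-batches for updates.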
results_sac = []
buffer = ReplayBuffer(capacity=10000)
minimal_size = 100
batch_size = 64
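
# Train for 500 episodes, performing one SAC update per environment step
# once the buffer holds more than minimal_size transitions.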
for epi in range(500):
    # The env RNG was already seeded once via env.reset(seed=env_seed) above;
    # re-seeding with a fixed seed here would make every episode start identically.
    observation, info = env.reset()
    terminated = False
    truncated = False
    epi_len = 0
    total_return = 0
    while not terminated and not truncated:
        # Sample an action from the current policy (batch of one, detached from the graph).
        action = sac_policy(torch.from_numpy(observation.reshape(1, -1))).detach().squeeze(0).numpy()
        prev_obs = observation
        observation, reward, terminated, truncated, info = env.step(action)
        buffer.add(prev_obs, action, reward, observation, terminated, truncated)
        epi_len += 1
        total_return += reward
        if buffer.size() > minimal_size:
            sampled = buffer.sample(batch_size)
            sac_policy.update(sampled)
    print("epi: {}; len: {}; return: {}".format(epi, epi_len, total_return))
    results_sac.append((epi_len, total_return))
env.close()
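
# The matplotlib import is otherwise unused; a minimal sketch to visualize the
# per-episode returns collected in results_sac (raw returns, no smoothing):
returns = [ret for _, ret in results_sac]
plt.plot(returns)
plt.xlabel("episode")
plt.ylabel("undiscounted return")
plt.title("SAC on Pendulum-v1")
plt.show()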