Skip to content

Commit dd4a011

Browse files
authored
Fix SAC loss explode (#333)
* change SAC action_bound_method to "clip" (tanh is hardcoded in forward) * docstring update * modelbase -> modelbased
1 parent 825da9b commit dd4a011

File tree

14 files changed

+81
-55
lines changed

14 files changed

+81
-55
lines changed

examples/modelbase/psrl.py

Lines changed: 0 additions & 1 deletion
This file was deleted.
File renamed without changes.

examples/modelbased/psrl.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
../../test/modelbased/test_psrl.py

examples/mujoco/run_experiments.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,6 @@ TASK=$1
66
echo "Experiments started."
77
for seed in $(seq 0 9)
88
do
9-
python mujoco_sac.py --task $TASK --epoch 200 --seed $seed --logdir $LOGDIR > ${TASK}_`date '+%m-%d-%H-%M-%S'`_seed_$seed.txt 2>&1
9+
python mujoco_sac.py --task $TASK --epoch 200 --seed $seed --logdir $LOGDIR > ${TASK}_`date '+%m-%d-%H-%M-%S'`_seed_$seed.txt 2>&1 &
1010
done
1111
echo "Experiments ended."

test/continuous/test_sac_with_il.py

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,14 @@ def get_args():
2020
parser.add_argument('--task', type=str, default='Pendulum-v0')
2121
parser.add_argument('--seed', type=int, default=0)
2222
parser.add_argument('--buffer-size', type=int, default=20000)
23-
parser.add_argument('--actor-lr', type=float, default=3e-4)
23+
parser.add_argument('--actor-lr', type=float, default=1e-3)
2424
parser.add_argument('--critic-lr', type=float, default=1e-3)
2525
parser.add_argument('--il-lr', type=float, default=1e-3)
2626
parser.add_argument('--gamma', type=float, default=0.99)
2727
parser.add_argument('--tau', type=float, default=0.005)
2828
parser.add_argument('--alpha', type=float, default=0.2)
29+
parser.add_argument('--auto-alpha', type=int, default=1)
30+
parser.add_argument('--alpha-lr', type=float, default=3e-4)
2931
parser.add_argument('--epoch', type=int, default=5)
3032
parser.add_argument('--step-per-epoch', type=int, default=24000)
3133
parser.add_argument('--il-step-per-epoch', type=int, default=500)
@@ -41,7 +43,7 @@ def get_args():
4143
parser.add_argument('--logdir', type=str, default='log')
4244
parser.add_argument('--render', type=float, default=0.)
4345
parser.add_argument('--rew-norm', action="store_true", default=False)
44-
parser.add_argument('--n-step', type=int, default=4)
46+
parser.add_argument('--n-step', type=int, default=3)
4547
parser.add_argument(
4648
'--device', type=str,
4749
default='cuda' if torch.cuda.is_available() else 'cpu')
@@ -85,6 +87,13 @@ def test_sac_with_il(args=get_args()):
8587
concat=True, device=args.device)
8688
critic2 = Critic(net_c2, device=args.device).to(args.device)
8789
critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr)
90+
91+
if args.auto_alpha:
92+
target_entropy = -np.prod(env.action_space.shape)
93+
log_alpha = torch.zeros(1, requires_grad=True, device=args.device)
94+
alpha_optim = torch.optim.Adam([log_alpha], lr=args.alpha_lr)
95+
args.alpha = (target_entropy, log_alpha, alpha_optim)
96+
8897
policy = SACPolicy(
8998
actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim,
9099
tau=args.tau, gamma=args.gamma, alpha=args.alpha,
@@ -135,18 +144,20 @@ def stop_fn(mean_rewards):
135144
args.action_shape, max_action=args.max_action, device=args.device
136145
).to(args.device)
137146
optim = torch.optim.Adam(net.parameters(), lr=args.il_lr)
138-
il_policy = ImitationPolicy(net, optim, mode='continuous')
147+
il_policy = ImitationPolicy(
148+
net, optim, mode='continuous', action_space=env.action_space,
149+
action_scaling=True, action_bound_method="clip")
139150
il_test_collector = Collector(
140151
il_policy,
141-
DummyVectorEnv(
142-
[lambda: gym.make(args.task) for _ in range(args.test_num)])
152+
DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.test_num)])
143153
)
144154
train_collector.reset()
145155
result = offpolicy_trainer(
146156
il_policy, train_collector, il_test_collector, args.epoch,
147157
args.il_step_per_epoch, args.step_per_collect, args.test_num,
148158
args.batch_size, stop_fn=stop_fn, save_fn=save_fn, logger=logger)
149159
assert stop_fn(result['best_reward'])
160+
150161
if __name__ == '__main__':
151162
pprint.pprint(result)
152163
# Let's watch its performance!

tianshou/policy/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from tianshou.policy.modelfree.discrete_sac import DiscreteSACPolicy
1313
from tianshou.policy.imitation.base import ImitationPolicy
1414
from tianshou.policy.imitation.discrete_bcq import DiscreteBCQPolicy
15-
from tianshou.policy.modelbase.psrl import PSRLPolicy
15+
from tianshou.policy.modelbased.psrl import PSRLPolicy
1616
from tianshou.policy.multiagent.mapolicy import MultiAgentPolicyManager
1717

1818

tianshou/policy/base.py

Lines changed: 38 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -12,39 +12,44 @@
1212
class BasePolicy(ABC, nn.Module):
1313
"""The base class for any RL policy.
1414
15-
Tianshou aims to modularizing RL algorithms. It comes into several classes
16-
of policies in Tianshou. All of the policy classes must inherit
15+
Tianshou aims to modularize RL algorithms. It comes into several classes of
16+
policies in Tianshou. All of the policy classes must inherit
1717
:class:`~tianshou.policy.BasePolicy`.
1818
19-
A policy class typically has four parts:
19+
A policy class typically has the following parts:
2020
21-
* :meth:`~tianshou.policy.BasePolicy.__init__`: initialize the policy, \
22-
including coping the target network and so on;
21+
* :meth:`~tianshou.policy.BasePolicy.__init__`: initialize the policy, including \
22+
copying the target network and so on;
2323
* :meth:`~tianshou.policy.BasePolicy.forward`: compute action with given \
2424
observation;
25-
* :meth:`~tianshou.policy.BasePolicy.process_fn`: pre-process data from \
26-
the replay buffer (this function can interact with replay buffer);
27-
* :meth:`~tianshou.policy.BasePolicy.learn`: update policy with a given \
28-
batch of data.
25+
* :meth:`~tianshou.policy.BasePolicy.process_fn`: pre-process data from the \
26+
replay buffer (this function can interact with replay buffer);
27+
* :meth:`~tianshou.policy.BasePolicy.learn`: update policy with a given batch of \
28+
data.
29+
* :meth:`~tianshou.policy.BasePolicy.post_process_fn`: update the replay buffer \
30+
from the learning process (e.g., prioritized replay buffer needs to update \
31+
the weight);
32+
* :meth:`~tianshou.policy.BasePolicy.update`: the main interface for training, \
33+
i.e., `process_fn -> learn -> post_process_fn`.
2934
3035
Most of the policy needs a neural network to predict the action and an
3136
optimizer to optimize the policy. The rules of self-defined networks are:
3237
33-
1. Input: observation "obs" (may be a ``numpy.ndarray``, a \
34-
``torch.Tensor``, a dict or any others), hidden state "state" (for RNN \
35-
usage), and other information "info" provided by the environment.
36-
2. Output: some "logits", the next hidden state "state", and the \
37-
intermediate result during policy forwarding procedure "policy". The \
38-
"logits" could be a tuple instead of a ``torch.Tensor``. It depends on how\
39-
the policy process the network output. For example, in PPO, the return of \
40-
the network might be ``(mu, sigma), state`` for Gaussian policy. The \
41-
"policy" can be a Batch of torch.Tensor or other things, which will be \
42-
stored in the replay buffer, and can be accessed in the policy update \
43-
process (e.g. in "policy.learn()", the "batch.policy" is what you need).
44-
45-
Since :class:`~tianshou.policy.BasePolicy` inherits ``torch.nn.Module``,
46-
you can use :class:`~tianshou.policy.BasePolicy` almost the same as
47-
``torch.nn.Module``, for instance, loading and saving the model:
38+
1. Input: observation "obs" (may be a ``numpy.ndarray``, a ``torch.Tensor``, a \
39+
dict or any others), hidden state "state" (for RNN usage), and other information \
40+
"info" provided by the environment.
41+
2. Output: some "logits", the next hidden state "state", and the intermediate \
42+
result during policy forwarding procedure "policy". The "logits" could be a tuple \
43+
instead of a ``torch.Tensor``. It depends on how the policy process the network \
44+
output. For example, in PPO, the return of the network might be \
45+
``(mu, sigma), state`` for Gaussian policy. The "policy" can be a Batch of \
46+
torch.Tensor or other things, which will be stored in the replay buffer, and can \
47+
be accessed in the policy update process (e.g. in "policy.learn()", the \
48+
"batch.policy" is what you need).
49+
50+
Since :class:`~tianshou.policy.BasePolicy` inherits ``torch.nn.Module``, you can
51+
use :class:`~tianshou.policy.BasePolicy` almost the same as ``torch.nn.Module``,
52+
for instance, loading and saving the model:
4853
::
4954
5055
torch.save(policy.state_dict(), "policy.pth")
@@ -117,6 +122,15 @@ def forward(
117122
return Batch(..., policy=Batch(log_prob=dist.log_prob(act)))
118123
# and in the sampled data batch, you can directly use
119124
# batch.policy.log_prob to get your data.
125+
126+
.. note::
127+
128+
In a continuous action space, you should do another step "map_action" to get
129+
the real action:
130+
::
131+
132+
act = policy(batch).act # doesn't map to the target action range
133+
act = policy.map_action(act, batch)
120134
"""
121135
pass
122136

0 commit comments

Comments
 (0)