Skip to content

Commit 04a3860

Browse files
committed
v2: Marker interfaces for actors, minor fixes
1 parent 0a7e4ea commit 04a3860

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

44 files changed

+173
-136
lines changed

CHANGELOG.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,7 @@ Developers:
183183

184184
* The `Actor` classes have been renamed for clarity:
185185
* `BaseActor` -> `Actor`
186-
* `continuous.ActorProb` -> `ContinuousActorProb`
186+
* `continuous.ActorProb` -> `ContinuousActorProbabilistic`
187187
* `continuous.Actor` -> `ContinuousActorDeterministic`
188188
* `discrete.Actor` -> `DiscreteActor`
189189
* The `Critic` classes have been renamed for clarity:
@@ -192,7 +192,7 @@ Developers:
192192
* Moved Atari helper modules `atari_network` and `atari_wrapper` to the library under `tianshou.env.atari`.
193193
* Fix issues pertaining to the torch device assignment of network components (#810):
194194
* Remove 'device' member (and the corresponding constructor argument) from the following classes:
195-
`BranchingNet`, `C51Net`, `ContinuousActorDeterministic`, `ContinuousActorProb`, `ContinuousCritic`,
195+
`BranchingNet`, `C51Net`, `ContinuousActorDeterministic`, `ContinuousActorProbabilistic`, `ContinuousCritic`,
196196
`DiscreteActor`, `DiscreteCritic`, `DQNet`, `FullQuantileFunction`, `ImplicitQuantileNetwork`,
197197
`IntrinsicCuriosityModule`, `Net`, `MLP`, `Perturbation`, `QRDQNet`, `Rainbow`, `Recurrent`,
198198
`RecurrentActorProb`, `RecurrentCritic`, `VAE`

examples/box2d/bipedal_hardcore_sac.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from tianshou.trainer import OffPolicyTrainerParams
1919
from tianshou.utils import TensorboardLogger
2020
from tianshou.utils.net.common import Net
21-
from tianshou.utils.net.continuous import ContinuousActorProb, ContinuousCritic
21+
from tianshou.utils.net.continuous import ContinuousActorProbabilistic, ContinuousCritic
2222
from tianshou.utils.space_info import SpaceInfo
2323

2424

@@ -111,7 +111,7 @@ def test_sac_bipedal(args: argparse.Namespace = get_args()) -> None:
111111

112112
# model
113113
net_a = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes)
114-
actor = ContinuousActorProb(
114+
actor = ContinuousActorProbabilistic(
115115
preprocess_net=net_a,
116116
action_shape=args.action_shape,
117117
unbounded=True,

examples/box2d/mcc_sac.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from tianshou.trainer import OffPolicyTrainerParams
1818
from tianshou.utils import TensorboardLogger
1919
from tianshou.utils.net.common import Net
20-
from tianshou.utils.net.continuous import ContinuousActorProb, ContinuousCritic
20+
from tianshou.utils.net.continuous import ContinuousActorProbabilistic, ContinuousCritic
2121
from tianshou.utils.space_info import SpaceInfo
2222

2323

@@ -69,7 +69,7 @@ def test_sac(args: argparse.Namespace = get_args()) -> None:
6969
test_envs.seed(args.seed)
7070
# model
7171
net = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes)
72-
actor = ContinuousActorProb(
72+
actor = ContinuousActorProbabilistic(
7373
preprocess_net=net, action_shape=args.action_shape, unbounded=True
7474
).to(args.device)
7575
actor_optim = AdamOptimizerFactory(lr=args.actor_lr)

examples/inverse/irl_gail.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,12 +25,12 @@
2525
from tianshou.env import SubprocVectorEnv, VectorEnvNormObs
2626
from tianshou.policy import GAIL
2727
from tianshou.policy.base import Algorithm
28-
from tianshou.policy.modelfree.pg import ActorPolicy
28+
from tianshou.policy.modelfree.pg import ActorPolicyProbabilistic
2929
from tianshou.policy.optim import AdamOptimizerFactory, LRSchedulerFactoryLinear
3030
from tianshou.trainer import OnPolicyTrainerParams
3131
from tianshou.utils import TensorboardLogger
3232
from tianshou.utils.net.common import Net
33-
from tianshou.utils.net.continuous import ContinuousActorProb, ContinuousCritic
33+
from tianshou.utils.net.continuous import ContinuousActorProbabilistic, ContinuousCritic
3434
from tianshou.utils.space_info import SpaceInfo
3535

3636

@@ -127,7 +127,7 @@ def test_gail(args: argparse.Namespace = get_args()) -> None:
127127
hidden_sizes=args.hidden_sizes,
128128
activation=nn.Tanh,
129129
)
130-
actor = ContinuousActorProb(
130+
actor = ContinuousActorProbabilistic(
131131
preprocess_net=net_a,
132132
action_shape=args.action_shape,
133133
unbounded=True,
@@ -204,7 +204,7 @@ def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution:
204204
)
205205
print("dataset loaded")
206206

207-
policy = ActorPolicy(
207+
policy = ActorPolicyProbabilistic(
208208
actor=actor,
209209
dist_fn=dist,
210210
action_scaling=True,

examples/mujoco/mujoco_a2c.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,11 @@
1515
from tianshou.highlevel.logger import LoggerFactoryDefault
1616
from tianshou.policy import A2C
1717
from tianshou.policy.base import Algorithm
18-
from tianshou.policy.modelfree.pg import ActorPolicy
18+
from tianshou.policy.modelfree.pg import ActorPolicyProbabilistic
1919
from tianshou.policy.optim import LRSchedulerFactoryLinear, RMSpropOptimizerFactory
2020
from tianshou.trainer import OnPolicyTrainerParams
2121
from tianshou.utils.net.common import ActorCritic, Net
22-
from tianshou.utils.net.continuous import ContinuousActorProb, ContinuousCritic
22+
from tianshou.utils.net.continuous import ContinuousActorProbabilistic, ContinuousCritic
2323

2424

2525
def get_args() -> argparse.Namespace:
@@ -94,7 +94,7 @@ def main(args: argparse.Namespace = get_args()) -> None:
9494
hidden_sizes=args.hidden_sizes,
9595
activation=nn.Tanh,
9696
)
97-
actor = ContinuousActorProb(
97+
actor = ContinuousActorProbabilistic(
9898
preprocess_net=net_a,
9999
action_shape=args.action_shape,
100100
unbounded=True,
@@ -140,7 +140,7 @@ def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution:
140140
loc, scale = loc_scale
141141
return Independent(Normal(loc, scale), 1)
142142

143-
policy = ActorPolicy(
143+
policy = ActorPolicyProbabilistic(
144144
actor=actor,
145145
dist_fn=dist,
146146
action_scaling=True,

examples/mujoco/mujoco_npg.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,11 @@
1515
from tianshou.highlevel.logger import LoggerFactoryDefault
1616
from tianshou.policy import NPG
1717
from tianshou.policy.base import Algorithm
18-
from tianshou.policy.modelfree.pg import ActorPolicy
18+
from tianshou.policy.modelfree.pg import ActorPolicyProbabilistic
1919
from tianshou.policy.optim import AdamOptimizerFactory, LRSchedulerFactoryLinear
2020
from tianshou.trainer import OnPolicyTrainerParams
2121
from tianshou.utils.net.common import Net
22-
from tianshou.utils.net.continuous import ContinuousActorProb, ContinuousCritic
22+
from tianshou.utils.net.continuous import ContinuousActorProbabilistic, ContinuousCritic
2323

2424

2525
def get_args() -> argparse.Namespace:
@@ -99,7 +99,7 @@ def main(args: argparse.Namespace = get_args()) -> None:
9999
hidden_sizes=args.hidden_sizes,
100100
activation=nn.Tanh,
101101
)
102-
actor = ContinuousActorProb(
102+
actor = ContinuousActorProbabilistic(
103103
preprocess_net=net_a,
104104
action_shape=args.action_shape,
105105
unbounded=True,
@@ -138,7 +138,7 @@ def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution:
138138
loc, scale = loc_scale
139139
return Independent(Normal(loc, scale), 1)
140140

141-
policy = ActorPolicy(
141+
policy = ActorPolicyProbabilistic(
142142
actor=actor,
143143
dist_fn=dist,
144144
action_scaling=True,

examples/mujoco/mujoco_ppo.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,11 @@
1515
from tianshou.highlevel.logger import LoggerFactoryDefault
1616
from tianshou.policy import PPO
1717
from tianshou.policy.base import Algorithm
18-
from tianshou.policy.modelfree.pg import ActorPolicy
18+
from tianshou.policy.modelfree.pg import ActorPolicyProbabilistic
1919
from tianshou.policy.optim import AdamOptimizerFactory, LRSchedulerFactoryLinear
2020
from tianshou.trainer import OnPolicyTrainerParams
2121
from tianshou.utils.net.common import ActorCritic, Net
22-
from tianshou.utils.net.continuous import ContinuousActorProb, ContinuousCritic
22+
from tianshou.utils.net.continuous import ContinuousActorProbabilistic, ContinuousCritic
2323

2424

2525
def get_args() -> argparse.Namespace:
@@ -99,7 +99,7 @@ def main(args: argparse.Namespace = get_args()) -> None:
9999
hidden_sizes=args.hidden_sizes,
100100
activation=nn.Tanh,
101101
)
102-
actor = ContinuousActorProb(
102+
actor = ContinuousActorProbabilistic(
103103
preprocess_net=net_a,
104104
action_shape=args.action_shape,
105105
unbounded=True,
@@ -141,7 +141,7 @@ def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution:
141141
loc, scale = loc_scale
142142
return Independent(Normal(loc, scale), 1)
143143

144-
policy = ActorPolicy(
144+
policy = ActorPolicyProbabilistic(
145145
actor=actor,
146146
dist_fn=dist,
147147
action_scaling=True,

examples/mujoco/mujoco_redq.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from tianshou.policy.optim import AdamOptimizerFactory
1919
from tianshou.trainer import OffPolicyTrainerParams
2020
from tianshou.utils.net.common import EnsembleLinear, Net
21-
from tianshou.utils.net.continuous import ContinuousActorProb, ContinuousCritic
21+
from tianshou.utils.net.continuous import ContinuousActorProbabilistic, ContinuousCritic
2222

2323

2424
def get_args() -> argparse.Namespace:
@@ -90,7 +90,7 @@ def main(args: argparse.Namespace = get_args()) -> None:
9090
torch.manual_seed(args.seed)
9191
# model
9292
net_a = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes)
93-
actor = ContinuousActorProb(
93+
actor = ContinuousActorProbabilistic(
9494
preprocess_net=net_a,
9595
action_shape=args.action_shape,
9696
unbounded=True,

examples/mujoco/mujoco_reinforce.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,11 @@
1515
from tianshou.highlevel.logger import LoggerFactoryDefault
1616
from tianshou.policy import Reinforce
1717
from tianshou.policy.base import Algorithm
18-
from tianshou.policy.modelfree.pg import ActorPolicy
18+
from tianshou.policy.modelfree.pg import ActorPolicyProbabilistic
1919
from tianshou.policy.optim import AdamOptimizerFactory, LRSchedulerFactoryLinear
2020
from tianshou.trainer import OnPolicyTrainerParams
2121
from tianshou.utils.net.common import Net
22-
from tianshou.utils.net.continuous import ContinuousActorProb
22+
from tianshou.utils.net.continuous import ContinuousActorProbabilistic
2323

2424

2525
def get_args() -> argparse.Namespace:
@@ -91,7 +91,7 @@ def main(args: argparse.Namespace = get_args()) -> None:
9191
hidden_sizes=args.hidden_sizes,
9292
activation=nn.Tanh,
9393
)
94-
actor = ContinuousActorProb(
94+
actor = ContinuousActorProbabilistic(
9595
preprocess_net=net_a,
9696
action_shape=args.action_shape,
9797
unbounded=True,
@@ -124,7 +124,7 @@ def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution:
124124
loc, scale = loc_scale
125125
return Independent(Normal(loc, scale), 1)
126126

127-
policy = ActorPolicy(
127+
policy = ActorPolicyProbabilistic(
128128
actor=actor,
129129
dist_fn=dist,
130130
action_space=env.action_space,

examples/mujoco/mujoco_sac.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from tianshou.policy.optim import AdamOptimizerFactory
1818
from tianshou.trainer import OffPolicyTrainerParams
1919
from tianshou.utils.net.common import Net
20-
from tianshou.utils.net.continuous import ContinuousActorProb, ContinuousCritic
20+
from tianshou.utils.net.continuous import ContinuousActorProbabilistic, ContinuousCritic
2121

2222

2323
def get_args() -> argparse.Namespace:
@@ -86,7 +86,7 @@ def main(args: argparse.Namespace = get_args()) -> None:
8686
torch.manual_seed(args.seed)
8787
# model
8888
net_a = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes)
89-
actor = ContinuousActorProb(
89+
actor = ContinuousActorProbabilistic(
9090
preprocess_net=net_a,
9191
action_shape=args.action_shape,
9292
unbounded=True,

0 commit comments

Comments
 (0)