Skip to content

Commit 0fa36cd

Browse files
committed
v2: Better handling of max_action in actor
1 parent ba32173 commit 0fa36cd

File tree

4 files changed

+19
-16
lines changed

4 files changed

+19
-16
lines changed

test/continuous/test_redq.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,6 @@ def test_redq(args: argparse.Namespace = get_args(), enable_assertions: bool = T
6464
space_info = SpaceInfo.from_env(env)
6565
args.state_shape = space_info.observation_info.obs_shape
6666
args.action_shape = space_info.action_info.action_shape
67-
args.max_action = space_info.action_info.max_action
6867
if args.reward_threshold is None:
6968
default_reward_threshold = {"Pendulum-v0": -250, "Pendulum-v1": -250}
7069
args.reward_threshold = default_reward_threshold.get(

test/offline/gather_pendulum_data.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,6 @@ def gather_data() -> VectorReplayBuffer:
7373
space_info = SpaceInfo.from_env(env)
7474
args.state_shape = space_info.observation_info.obs_shape
7575
args.action_shape = space_info.action_info.action_shape
76-
args.max_action = space_info.action_info.max_action
7776

7877
if args.reward_threshold is None:
7978
default_reward_threshold = {"Pendulum-v0": -250, "Pendulum-v1": -250}

test/offline/test_bcq.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ def test_bcq(args: argparse.Namespace = get_args(), enable_assertions: bool = Tr
104104
output_dim=args.action_dim,
105105
hidden_sizes=args.hidden_sizes,
106106
)
107-
actor = Perturbation(preprocess_net=net_a, max_action=args.max_action, phi=args.phi).to(
107+
actor_perturbation = Perturbation(preprocess_net=net_a, max_action=args.max_action, phi=args.phi).to(
108108
args.device,
109109
)
110110
actor_optim = AdamOptimizerFactory(lr=args.actor_lr)
@@ -141,7 +141,7 @@ def test_bcq(args: argparse.Namespace = get_args(), enable_assertions: bool = Tr
141141
vae_optim = AdamOptimizerFactory()
142142

143143
policy = BCQPolicy(
144-
actor_perturbation=actor,
144+
actor_perturbation=actor_perturbation,
145145
critic=critic,
146146
vae=vae,
147147
action_space=env.action_space,

tianshou/algorithm/modelfree/reinforce.py

Lines changed: 17 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -148,14 +148,20 @@ def __init__(
148148
action_scaling=action_scaling,
149149
action_bound_method=action_bound_method,
150150
)
151-
if action_scaling and not np.isclose(actor.max_action, 1.0):
152-
warnings.warn(
153-
"action_scaling and action_bound_method are only intended "
154-
"to deal with unbounded model action space, but find actor model "
155-
f"bound action space with max_action={actor.max_action}. "
156-
"Consider using unbounded=True option of the actor model, "
157-
"or set action_scaling to False and action_bound_method to None.",
158-
)
151+
if action_scaling:
152+
try:
153+
max_action = float(actor.max_action) # type: ignore
154+
if not np.isclose(max_action, 1.0):
155+
warnings.warn(
156+
"action_scaling and action_bound_method are only intended "
157+
"to deal with unbounded model action space, but find actor model "
158+
f"bound action space with max_action={actor.max_action}. "
159+
"Consider using unbounded=True option of the actor model, "
160+
"or set action_scaling to False and action_bound_method to None.",
161+
)
162+
except (AttributeError, TypeError, ValueError):
163+
pass
164+
159165
self.actor = actor
160166
self.dist_fn = dist_fn
161167
self._eps = 1e-8
@@ -286,7 +292,7 @@ def add_discounted_returns(
286292
should be marked by done flag, unfinished (or collecting) episodes will be
287293
recognized by buffer.unfinished_index().
288294
:param buffer: the corresponding replay buffer.
289-
:param numpy.ndarray indices: tell batch's location in buffer, batch is equal
295+
:param indices: tell batch's location in buffer, batch is equal
290296
to buffer[indices].
291297
"""
292298
v_s_ = np.full(indices.shape, self.ret_rms.mean)
@@ -306,8 +312,7 @@ def add_discounted_returns(
306312
self.ret_rms.update(unnormalized_returns)
307313
else:
308314
batch.returns = unnormalized_returns
309-
batch: BatchWithReturnsProtocol
310-
return batch
315+
return cast(BatchWithReturnsProtocol, batch)
311316

312317

313318
class Reinforce(OnPolicyAlgorithm[ActorPolicyProbabilistic]):
@@ -316,7 +321,7 @@ class Reinforce(OnPolicyAlgorithm[ActorPolicyProbabilistic]):
316321
def __init__(
317322
self,
318323
*,
319-
policy: TActorPolicy,
324+
policy: ActorPolicyProbabilistic,
320325
gamma: float = 0.99,
321326
return_standardization: bool = False,
322327
optim: OptimizerFactory,

0 commit comments

Comments (0)