Skip to content

Commit 0312351

Browse files
committed
v2: add_exploration_noise - raise error on wrong type instead of doing nothing
1 parent 2972b13 commit 0312351

File tree

2 files changed

+17
-8
lines changed

2 files changed

+17
-8
lines changed

tianshou/algorithm/modelfree/bdqn.py

Lines changed: 9 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,4 @@
1-
from typing import Any, cast
1+
from typing import cast
22

33
import gymnasium as gym
44
import numpy as np
@@ -65,7 +65,6 @@ def forward(
6565
batch: ObsBatchProtocol,
6666
state: dict | BatchProtocol | np.ndarray | None = None,
6767
model: torch.nn.Module | None = None,
68-
**kwargs: Any,
6968
) -> ModelOutputBatchProtocol:
7069
if model is None:
7170
model = self.model
@@ -84,8 +83,9 @@ def add_exploration_noise(
8483
batch: ObsBatchProtocol,
8584
) -> TArrOrActBatch:
8685
eps = self.eps_training if self.is_within_training_step else self.eps_inference
87-
# TODO: This looks problematic; the non-array case is silently ignored
88-
if isinstance(act, np.ndarray) and not np.isclose(eps, 0.0):
86+
if not np.isclose(eps, 0.0):
87+
return act
88+
if isinstance(act, np.ndarray):
8989
bsz = len(act)
9090
rand_mask = np.random.rand(bsz) < eps
9191
rand_act = np.random.randint(
@@ -96,7 +96,11 @@ def add_exploration_noise(
9696
if hasattr(batch.obs, "mask"):
9797
rand_act += batch.obs.mask
9898
act[rand_mask] = rand_act[rand_mask]
99-
return act
99+
return act
100+
else:
101+
raise NotImplementedError(
102+
f"Currently only numpy arrays are supported, got {type(act)=}."
103+
)
100104

101105

102106
class BDQN(QLearningOffPolicyAlgorithm[BDQNPolicy]):

tianshou/algorithm/modelfree/dqn.py

Lines changed: 8 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -156,8 +156,10 @@ def add_exploration_noise(
156156
batch: ObsBatchProtocol,
157157
) -> TArrOrActBatch:
158158
eps = self.eps_training if self.is_within_training_step else self.eps_inference
159-
# TODO: This looks problematic; the non-array case is silently ignored
160-
if isinstance(act, np.ndarray) and not np.isclose(eps, 0.0):
159+
eps = self.eps_training if self.is_within_training_step else self.eps_inference
160+
if not np.isclose(eps, 0.0):
161+
return act
162+
if isinstance(act, np.ndarray):
161163
batch_size = len(act)
162164
rand_mask = np.random.rand(batch_size) < eps
163165
self.action_space = cast(Discrete, self.action_space) # for mypy
@@ -167,7 +169,10 @@ def add_exploration_noise(
167169
q += batch.obs.mask
168170
rand_act = q.argmax(axis=1)
169171
act[rand_mask] = rand_act[rand_mask]
170-
return act
172+
return act
173+
raise NotImplementedError(
174+
f"Currently only numpy array is supported for action, but got {type(act)}"
175+
)
171176

172177

173178
TDQNPolicy = TypeVar("TDQNPolicy", bound=DiscreteQLearningPolicy)

0 commit comments

Comments
 (0)