Skip to content

Commit 0a7e4ea

Browse files
committed
Merge remote-tracking branch 'thuml/dev-v2' into dev-v2
# Conflicts: # tianshou/utils/net/common.py
2 parents 2535c8c + 6684f0f commit 0a7e4ea

File tree

3 files changed

+15
-15
lines changed

3 files changed

+15
-15
lines changed

tianshou/utils/net/common.py

Lines changed: 8 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -674,17 +674,17 @@ def get_preprocess_net(self) -> ModuleWithVectorOutput:
674674
def forward(
675675
self,
676676
obs: np.ndarray | torch.Tensor,
677-
state: T | None = None,
677+
rnn_hidden_state: T | None = None,
678678
info: dict[str, Any] | None = None,
679679
) -> tuple[np.ndarray | torch.Tensor, T | None]:
680680
"""
681681
The main method for tianshou to compute actions from env observations.
682682
Implementations will always make use of the preprocess_net as the first processing step.
683-
684-
:param obs: the observation to be passed to the actor.
685-
:param rnn_hidden_state: the hidden state of the RNN, if applicable.
686-
:param info: the info object from the env step
687-
:return: a tuple (action_repr, hidden_state), where action_repr is either an actual action for the environment or
683+
684+
:param obs: the observation from the environment
685+
:param rnn_hidden_state: the hidden state of the RNN, if applicable
686+
:param info: the info object from the environment step
687+
:return: a tuple (action_repr, hidden_state), where action_repr is either an actual action for the environment or
688688
a representation from which it can be retrieved/sampled (e.g., mean and std for a Gaussian distribution),
689689
and hidden_state is the new hidden state of the RNN, if applicable.
690690
"""
@@ -729,7 +729,7 @@ def is_discrete(self) -> bool:
729729
def forward(
730730
self,
731731
obs: np.ndarray | torch.Tensor | BatchProtocol,
732-
state: Any | None = None,
732+
rnn_hidden_state: Any | None = None,
733733
info: dict[str, Any] | None = None,
734734
) -> tuple[np.ndarray, Any | None]:
735735
batch_size = len(obs)
@@ -738,7 +738,7 @@ def forward(
738738
else:
739739
# Discrete Actors currently return an n-dimensional array of probabilities for each action
740740
action = 1 / self.action_space.n * np.ones((batch_size, self.action_space.n))
741-
return action, state
741+
return action, rnn_hidden_state
742742

743743
def compute_action_batch(self, obs: np.ndarray | torch.Tensor | BatchProtocol) -> np.ndarray:
744744
if self.is_discrete:

tianshou/utils/net/continuous.py

Lines changed: 5 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -66,7 +66,7 @@ def get_output_dim(self) -> int:
6666
def forward(
6767
self,
6868
obs: np.ndarray | torch.Tensor,
69-
state: Any = None,
69+
rnn_hidden_state: Any = None,
7070
info: dict[str, Any] | None = None,
7171
) -> tuple[torch.Tensor, Any]:
7272
"""Mapping: s_B -> action_values_BA, hidden_state_BH | None.
@@ -76,7 +76,7 @@ def forward(
7676
The hidden state is only not None if a recurrent net is used as part of the
7777
learning algorithm (support for RNNs is currently experimental).
7878
"""
79-
action_BA, hidden_BH = self.preprocess(obs, state)
79+
action_BA, hidden_BH = self.preprocess(obs, rnn_hidden_state)
8080
action_BA = self.max_action * torch.tanh(self.last(action_BA))
8181
return action_BA, hidden_BH
8282

@@ -222,13 +222,13 @@ def get_preprocess_net(self) -> ModuleWithVectorOutput:
222222
def forward(
223223
self,
224224
obs: np.ndarray | torch.Tensor,
225-
state: Any = None,
225+
rnn_hidden_state: Any = None,
226226
info: dict[str, Any] | None = None,
227227
) -> tuple[tuple[torch.Tensor, torch.Tensor], Any]:
228228
"""Mapping: obs -> logits -> (mu, sigma)."""
229229
if info is None:
230230
info = {}
231-
logits, hidden = self.preprocess(obs, state)
231+
logits, hidden = self.preprocess(obs, rnn_hidden_state)
232232
mu = self.mu(logits)
233233
if not self._unbounded:
234234
mu = self.max_action * torch.tanh(mu)
@@ -238,7 +238,7 @@ def forward(
238238
shape = [1] * len(mu.shape)
239239
shape[1] = -1
240240
sigma = (self.sigma_param.view(shape) + torch.zeros_like(mu)).exp()
241-
return (mu, sigma), state
241+
return (mu, sigma), rnn_hidden_state
242242

243243

244244
class RecurrentActorProb(nn.Module):

tianshou/utils/net/discrete.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -59,7 +59,7 @@ def get_preprocess_net(self) -> ModuleWithVectorOutput:
5959
def forward(
6060
self,
6161
obs: np.ndarray | torch.Tensor,
62-
state: Any = None,
62+
rnn_hidden_state: Any = None,
6363
info: dict[str, Any] | None = None,
6464
) -> tuple[torch.Tensor, torch.Tensor | None]:
6565
r"""Mapping: s_B -> action_values_BA, hidden_state_BH | None.
@@ -71,7 +71,7 @@ def forward(
7171
The hidden state is only
7272
not None if a recurrent net is used as part of the learning algorithm.
7373
"""
74-
x, hidden_BH = self.preprocess(obs, state)
74+
x, hidden_BH = self.preprocess(obs, rnn_hidden_state)
7575
x = self.last(x)
7676
if self.softmax_output:
7777
x = F.softmax(x, dim=-1)

0 commit comments

Comments (0)