@@ -674,17 +674,17 @@ def get_preprocess_net(self) -> ModuleWithVectorOutput:
674674 def forward (
675675 self ,
676676 obs : np .ndarray | torch .Tensor ,
677- state : T | None = None ,
677+ rnn_hidden_state : T | None = None ,
678678 info : dict [str , Any ] | None = None ,
679679 ) -> tuple [np .ndarray | torch .Tensor , T | None ]:
680680 """
681681 The main method for tianshou to compute actions from env observations.
682682 Implementations will always make use of the preprocess_net as the first processing step.
683-
684- :param obs: the observation to be passed to the actor.
685- :param rnn_hidden_state: the hidden state of the RNN, if applicable.
686- :param info: the info object from the env step
687- :return: a tuple (action_repr, hidden_state), where action_repr is either an actual action for the environment or
683+
684+ :param obs: the observation from the environment
685+ :param rnn_hidden_state: the hidden state of the RNN, if applicable
686+ :param info: the info object from the environment step
687+ :return: a tuple (action_repr, hidden_state), where action_repr is either an actual action for the environment or
688688 a representation from which it can be retrieved/sampled (e.g., mean and std for a Gaussian distribution),
689689 and hidden_state is the new hidden state of the RNN, if applicable.
690690 """
@@ -729,7 +729,7 @@ def is_discrete(self) -> bool:
729729 def forward (
730730 self ,
731731 obs : np .ndarray | torch .Tensor | BatchProtocol ,
732- state : Any | None = None ,
732+ rnn_hidden_state : Any | None = None ,
733733 info : dict [str , Any ] | None = None ,
734734 ) -> tuple [np .ndarray , Any | None ]:
735735 batch_size = len (obs )
@@ -738,7 +738,7 @@ def forward(
738738 else :
739739 # Discrete Actors currently return an n-dimensional array of probabilities for each action
740740 action = 1 / self .action_space .n * np .ones ((batch_size , self .action_space .n ))
741- return action , state
741+ return action , rnn_hidden_state
742742
743743 def compute_action_batch (self , obs : np .ndarray | torch .Tensor | BatchProtocol ) -> np .ndarray :
744744 if self .is_discrete :
0 commit comments