1212from overrides import override
1313from torch .distributions import Categorical , Distribution
1414
15+ from tianshou .config import ENABLE_VALIDATION
1516from tianshou .data import (
1617 Batch ,
1718 CachedReplayBuffer ,
@@ -318,8 +319,32 @@ def __init__(
318319 exploration_noise : bool = False ,
319320 # The typing is correct, there's a bug in mypy, see https://github.com/python/mypy/issues/3737
320321 collect_stats_class : type [TCollectStats ] = CollectStats , # type: ignore[assignment]
321- raise_on_nan_in_buffer : bool = True ,
322+ raise_on_nan_in_buffer : bool = ENABLE_VALIDATION ,
322323 ) -> None :
324+ """
325+ :param policy: a tianshou policy, each :class:`BasePolicy` is capable of computing a batch
326+ of actions from a batch of observations.
327+ :param env: a ``gymnasium.Env`` environment or a vectorized instance of the
328+ :class:`~tianshou.env.BaseVectorEnv` class. The latter is strongly recommended, as with
329+ a gymnasium env the collection will not happen in parallel (a `DummyVectorEnv`
330+ will be constructed internally from the passed env)
331+ :param buffer: an instance of the :class:`~tianshou.data.ReplayBuffer` class.
332+ If set to None, will instantiate a :class:`~tianshou.data.VectorReplayBuffer`
333+ of size :data:`DEFAULT_BUFFER_MAXSIZE` * (number of envs)
334+ as the default buffer.
335+ :param exploration_noise: determine whether the action needs to be modified
336+ with the corresponding policy's exploration noise. If so, "policy.
337+ exploration_noise(act, batch)" will be called automatically to add the
338+ exploration noise into action.
340+ :param raise_on_nan_in_buffer: whether to raise a `RuntimeError` if NaNs are found in the buffer after
341+ a collection step. Especially useful when episode-level hooks are passed for making
342+ sure that nothing is broken during the collection. Consider setting to False if
343+ the NaN-check becomes a bottleneck.
344+ :param collect_stats_class: the class to use for collecting statistics. Allows customizing
345+ the stats collection logic by passing a subclass of :class:`CollectStats`. Changing
346+ this is rarely necessary and is mainly done by "power users".
347+ """
323348 if isinstance (env , gym .Env ) and not hasattr (env , "__len__" ):
324349 warnings .warn ("Single environment detected, wrap to DummyVectorEnv." )
325350 # Unfortunately, mypy seems to ignore the isinstance in lambda, maybe a bug in mypy
@@ -557,7 +582,7 @@ def __init__(
557582 exploration_noise : bool = False ,
558583 on_episode_done_hook : Optional ["EpisodeRolloutHookProtocol" ] = None ,
559584 on_step_hook : Optional ["StepHookProtocol" ] = None ,
560- raise_on_nan_in_buffer : bool = True ,
585+ raise_on_nan_in_buffer : bool = ENABLE_VALIDATION ,
561586 collect_stats_class : type [TCollectStats ] = CollectStats , # type: ignore[assignment]
562587 ) -> None :
563588 """
@@ -574,7 +599,7 @@ def __init__(
574599 :param exploration_noise: determine whether the action needs to be modified
575600 with the corresponding policy's exploration noise. If so, "policy.
576601 exploration_noise(act, batch)" will be called automatically to add the
577- exploration noise into action..
602+ exploration noise into action.
578603 :param on_episode_done_hook: if passed will be executed when an episode is done.
579604 The input to the hook will be a `RolloutBatch` that contains the entire episode (and nothing else).
580605 If a dict is returned by the hook it will be used to add new entries to the buffer
@@ -1045,7 +1070,7 @@ def _collect( # noqa: C901
10451070 break
10461071
10471072 # Sanity check: fail loudly if the buffer contains null/NaN entries
1048- if self .buffer .hasnull ():
1073+ if self .raise_on_nan_in_buffer and self . buffer .hasnull ():
10491074 nan_batch = self .buffer .isnull ().apply_values_transform (np .sum )
10501075
10511076 raise MalformedBufferError (
0 commit comments