Skip to content

Commit 5c1e4b2

Browse files
committed
v1: disable buffer hasnull checks by default
Control validation enabling with global flag
1 parent 06fba02 commit 5c1e4b2

File tree

4 files changed

+35
-6
lines changed

4 files changed

+35
-6
lines changed

tianshou/config.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
ENABLE_VALIDATION = False
2+
"""Validation can help catching bugs and issues but it slows down training and collection. Enable it only if needed."""

tianshou/data/collector.py

Lines changed: 29 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from overrides import override
1313
from torch.distributions import Categorical, Distribution
1414

15+
from tianshou.config import ENABLE_VALIDATION
1516
from tianshou.data import (
1617
Batch,
1718
CachedReplayBuffer,
@@ -318,8 +319,32 @@ def __init__(
318319
exploration_noise: bool = False,
319320
# The typing is correct, there's a bug in mypy, see https://github.com/python/mypy/issues/3737
320321
collect_stats_class: type[TCollectStats] = CollectStats, # type: ignore[assignment]
321-
raise_on_nan_in_buffer: bool = True,
322+
raise_on_nan_in_buffer: bool = ENABLE_VALIDATION,
322323
) -> None:
324+
"""
325+
:param policy: a tianshou policy, each :class:`BasePolicy` is capable of computing a batch
326+
of actions from a batch of observations.
327+
:param env: a ``gymnasium.Env`` environment or a vectorized instance of the
328+
:class:`~tianshou.env.BaseVectorEnv` class. The latter is strongly recommended, as with
329+
a gymnasium env the collection will not happen in parallel (a `DummyVectorEnv`
330+
will be constructed internally from the passed env)
331+
:param buffer: an instance of the :class:`~tianshou.data.ReplayBuffer` class.
332+
If set to None, will instantiate a :class:`~tianshou.data.VectorReplayBuffer`
333+
of size :data:`DEFAULT_BUFFER_MAXSIZE` * (number of envs)
334+
as the default buffer.
335+
:param exploration_noise: determine whether the action needs to be modified
336+
with the corresponding policy's exploration noise. If so, "policy.
337+
exploration_noise(act, batch)" will be called automatically to add the
338+
exploration noise into action.
339+
the rollout batch with this hook also modifies the data that is collected to the buffer!
340+
:param raise_on_nan_in_buffer: whether to raise a `RuntimeError` if NaNs are found in the buffer after
341+
a collection step. Especially useful when episode-level hooks are passed for making
342+
sure that nothing is broken during the collection. Consider setting to False if
343+
the NaN-check becomes a bottleneck.
344+
:param collect_stats_class: the class to use for collecting statistics. Allows customizing
345+
the stats collection logic by passing a subclass of :class:`CollectStats`. Changing
346+
this is rarely necessary and is mainly done by "power users".
347+
"""
323348
if isinstance(env, gym.Env) and not hasattr(env, "__len__"):
324349
warnings.warn("Single environment detected, wrap to DummyVectorEnv.")
325350
# Unfortunately, mypy seems to ignore the isinstance in lambda, maybe a bug in mypy
@@ -557,7 +582,7 @@ def __init__(
557582
exploration_noise: bool = False,
558583
on_episode_done_hook: Optional["EpisodeRolloutHookProtocol"] = None,
559584
on_step_hook: Optional["StepHookProtocol"] = None,
560-
raise_on_nan_in_buffer: bool = True,
585+
raise_on_nan_in_buffer: bool = ENABLE_VALIDATION,
561586
collect_stats_class: type[TCollectStats] = CollectStats, # type: ignore[assignment]
562587
) -> None:
563588
"""
@@ -574,7 +599,7 @@ def __init__(
574599
:param exploration_noise: determine whether the action needs to be modified
575600
with the corresponding policy's exploration noise. If so, "policy.
576601
exploration_noise(act, batch)" will be called automatically to add the
577-
exploration noise into action..
602+
exploration noise into action.
578603
:param on_episode_done_hook: if passed will be executed when an episode is done.
579604
The input to the hook will be a `RolloutBatch` that contains the entire episode (and nothing else).
580605
If a dict is returned by the hook it will be used to add new entries to the buffer
@@ -1045,7 +1070,7 @@ def _collect( # noqa: C901
10451070
break
10461071

10471072
# Check if we screwed up somewhere
1048-
if self.buffer.hasnull():
1073+
if self.raise_on_nan_in_buffer and self.buffer.hasnull():
10491074
nan_batch = self.buffer.isnull().apply_values_transform(np.sum)
10501075

10511076
raise MalformedBufferError(

tianshou/highlevel/env.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -389,7 +389,9 @@ def _next_seed(rng: np.random.Generator) -> int:
389389
:param rng: the random number generator
390390
:return: the sampled random seed
391391
"""
392-
return int(rng.integers(-2**31, 2**31, dtype=np.int32)) # int32 is needed for envpool compatibility
392+
return int(
393+
rng.integers(-(2**31), 2**31, dtype=np.int32)
394+
) # int32 is needed for envpool compatibility
393395

394396
@abstractmethod
395397
def _create_env(self, mode: EnvMode) -> Env:

tianshou/trainer/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -543,7 +543,7 @@ def _collect_training_data(self) -> CollectStats:
543543
lambda: f"Collected {collect_stats.n_collected_steps} steps, {collect_stats.n_collected_episodes} episodes",
544544
)
545545

546-
if self.train_collector.buffer.hasnull():
546+
if self.train_collector.raise_on_nan_in_buffer and self.train_collector.buffer.hasnull():
547547
from tianshou.data.collector import EpisodeRolloutHook
548548
from tianshou.env import DummyVectorEnv
549549

0 commit comments

Comments
 (0)